mistralrs_server_core/
mistralrs_for_server_builder.rs

//! ## mistral.rs instance for server builder.

use std::collections::HashMap;
use std::sync::Arc;

use anyhow::{Context, Result};
use candle_core::Device;
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, paged_attn_supported,
    parse_isq_value, AutoDeviceMapParams, BertEmbeddingModel, DefaultSchedulerMethod,
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, Loader, LoaderBuilder,
    McpClientConfig, MemoryGpuConfig, MistralRsBuilder, ModelSelected, PagedAttentionConfig,
    PagedCacheType, SchedulerConfig, SearchCallback, TokenSource,
};
use tracing::{info, warn};

use crate::types::{LoadedPipeline, SharedMistralRsState};

/// Configuration for a single model in a multi-model setup.
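///
/// A minimal construction sketch (the `model` value and template path are stand-ins):
/// ```ignore
/// let config = ModelConfig::new("my-model".to_string(), model)
///     .with_chat_template("chat_templates/llama.json".to_string())
///     .with_in_situ_quant("8".to_string());
/// ```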
#[derive(Clone, serde::Deserialize)]
pub struct ModelConfig {
    /// Model identifier (used in API requests)
    pub model_id: String,
    /// Model selector
    pub model: ModelSelected,
    /// Model-specific chat template
    pub chat_template: Option<String>,
    /// Model-specific JINJA template
    pub jinja_explicit: Option<String>,
    /// Model-specific device layers
    pub num_device_layers: Option<Vec<String>>,
    /// Model-specific in-situ quantization
    pub in_situ_quant: Option<String>,
}

impl ModelConfig {
    pub fn new(model_id: String, model: ModelSelected) -> Self {
        Self {
            model_id,
            model,
            chat_template: None,
            jinja_explicit: None,
            num_device_layers: None,
            in_situ_quant: None,
        }
    }

    pub fn with_chat_template(mut self, chat_template: String) -> Self {
        self.chat_template = Some(chat_template);
        self
    }

    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
        self.num_device_layers = Some(num_device_layers);
        self
    }

    pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
        self.in_situ_quant = Some(in_situ_quant);
        self
    }
}

pub mod defaults {
    //! Provides the default values used for the mistral.rs server instance.
    //! These defaults can be used for CLI argument fallbacks, config loading, or general initialization.

    use std::sync::Arc;

    use mistralrs_core::PagedCacheType;

    pub const DEVICE: Option<candle_core::Device> = None;
    pub const SEED: Option<u64> = None;
    pub const LOG: Option<String> = None;
    pub const TRUNCATE_SEQUENCE: bool = false;
    pub const MODEL: Option<mistralrs_core::ModelSelected> = None;
    pub const MAX_SEQS: usize = 16;
    pub const NO_KV_CACHE: bool = false;
    pub const CHAT_TEMPLATE: Option<String> = None;
    pub const JINJA_EXPLICIT: Option<String> = None;
    pub const INTERACTIVE_MODE: bool = false;
    pub const PREFIX_CACHE_N: usize = 16;
    pub const NUM_DEVICE_LAYERS: Option<Vec<String>> = None;
    pub const IN_SITU_QUANT: Option<String> = None;
    pub const PAGED_ATTN_GPU_MEM: Option<usize> = None;
    pub const PAGED_ATTN_GPU_MEM_USAGE: Option<f32> = None;
    pub const PAGED_CTXT_LEN: Option<usize> = None;
    pub const PAGED_ATTN_BLOCK_SIZE: Option<usize> = None;
    pub const PAGED_ATTN: Option<bool> = None;
    pub const PAGED_ATTN_CPU: bool = false;
    pub const PAGED_ATTN_CUDA: bool = true;
    pub const PAGED_ATTN_METAL: bool = false;
    pub const CPU: bool = false;
    pub const ENABLE_SEARCH: bool = false;
    pub const SEARCH_BERT_MODEL: Option<String> = None;
    pub const TOKEN_SOURCE: mistralrs_core::TokenSource = mistralrs_core::TokenSource::CacheToken;
    pub const SEARCH_CALLBACK: Option<Arc<mistralrs_core::SearchCallback>> = None;
    pub const PAGED_CACHE_TYPE: PagedCacheType = PagedCacheType::Auto;
}
/// A builder for creating a mistral.rs instance with configured options for the mistral.rs server.
///
/// ### Examples
///
/// Basic usage:
/// ```ignore
/// use mistralrs_server_core::mistralrs_for_server_builder::{
///     configure_paged_attn_from_flags, MistralRsForServerBuilder,
/// };
///
/// let args = Args::parse();
///
/// let paged_attn = configure_paged_attn_from_flags(args.paged_attn, args.no_paged_attn)?;
///
/// let mistralrs = MistralRsForServerBuilder::new()
///     .with_truncate_sequence(args.truncate_sequence)
///     .with_model(args.model)
///     .with_max_seqs(args.max_seqs)
///     .with_no_kv_cache(args.no_kv_cache)
///     .with_token_source(args.token_source)
///     .with_interactive_mode(args.interactive_mode)
///     .with_prefix_cache_n(args.prefix_cache_n)
///     .set_paged_attn(paged_attn)
///     .with_cpu(args.cpu)
///     .with_enable_search(args.enable_search)
///     .with_seed_optional(args.seed)
///     .with_log_optional(args.log)
///     .with_chat_template_optional(args.chat_template)
///     .with_jinja_explicit_optional(args.jinja_explicit)
///     .with_num_device_layers_optional(args.num_device_layers)
///     .with_in_situ_quant_optional(args.in_situ_quant)
///     .with_paged_attn_gpu_mem_optional(args.paged_attn_gpu_mem)
///     .with_paged_attn_gpu_mem_usage_optional(args.paged_attn_gpu_mem_usage)
///     .with_paged_ctxt_len_optional(args.paged_ctxt_len)
///     .with_paged_attn_block_size_optional(args.paged_attn_block_size)
///     .build()
///     .await?;
/// ```
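///
/// Multi-model usage (a sketch; the `ModelSelected` values are stand-ins):
/// ```ignore
/// let mistralrs = MistralRsForServerBuilder::new()
///     .add_model("model-a".to_string(), model_a)
///     .add_model("model-b".to_string(), model_b)
///     .build()
///     .await?;
/// ```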
pub struct MistralRsForServerBuilder {
    /// The Candle device to use for model execution (CPU, CUDA, Metal, etc.).
    device: Option<Device>,

    /// Integer seed to ensure reproducible random number generation.
    seed: Option<u64>,

    /// Log all responses and requests to this file
    log: Option<String>,

    /// If a sequence is larger than the maximum model length, truncate the number
    /// of tokens such that the sequence will fit at most the maximum length.
    /// If `max_tokens` is not specified in the request, space for 10 tokens will be reserved instead.
    truncate_sequence: bool,

    /// Model selector (single-model mode; deprecated in favor of `models`)
    model: Option<ModelSelected>,

    /// Multiple model configurations (for multi-model mode)
    models: Vec<ModelConfig>,

    /// Default model ID to use when none is specified in requests
    default_model_id: Option<String>,

    /// Maximum running sequences at any time. If the `tgt_non_granular_index` flag is set for X-LoRA models, this will be set to 1.
    max_seqs: usize,

    /// Use no KV cache.
    no_kv_cache: bool,

    /// Chat template: a JINJA file with `messages`, `add_generation_prompt`, `bos_token`, `eos_token`, and `unk_token` as inputs.
    /// Used if the automatic deserialization fails. If this ends with `.json` (i.e., it is a file), then that template is loaded.
    chat_template: Option<String>,

    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
    jinja_explicit: Option<String>,

    /// Source of the token for authentication.
    /// Can be in the formats: `literal:<value>`, `env:<value>`, `path:<value>`, `cache` to use a cached token, or `none` to use no token.
    /// Defaults to `cache`.
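    /// For example, `env:HF_TOKEN` reads the token from the `HF_TOKEN` environment variable.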
    token_source: TokenSource,

    /// Enter interactive mode instead of serving a chat server.
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device. Other caches are evicted to the CPU based on an LRU strategy.
    prefix_cache_n: usize,

    /// NOTE: This can be omitted to use automatic device mapping!
    /// Number of device layers to load and run on GPU(s). All others will be on the CPU.
    /// If one GPU is used, then this value should be an integer. Otherwise, each entry follows the
    /// pattern `ORD:NUM`, where ORD is a unique device ordinal and NUM is the number of layers for that device.
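    /// For example, `["0:16", "1:16"]` places 16 layers on each of devices 0 and 1.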
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for KV cache with PagedAttention in MBs.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    paged_attn_gpu_mem: Option<usize>,

    /// Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
    /// If this is not set and the device is CUDA, it will default to `0.9`.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold).
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    /// This is the default setting, and it defaults to the `max-seq-len` specified after the model type.
    paged_ctxt_len: Option<usize>,

    /// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    paged_attn_block_size: Option<usize>,

    /// Enables or disables PagedAttention. By default, PagedAttention will be enabled for CUDA and disabled for Metal (and is not supported for CPU). Use this to override the default behavior.
    paged_attn: Option<bool>,

    /// Use CPU only
    cpu: bool,

    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified below or the default.
    enable_search: bool,

    /// Specify a Hugging Face model ID for a BERT model to assist web searching. Defaults to Snowflake Arctic Embed L.
    search_bert_model: Option<String>,

    /// Optional override search callback
    search_callback: Option<Arc<SearchCallback>>,

    /// Optional MCP client configuration
    mcp_client_config: Option<McpClientConfig>,

    /// PagedAttention KV cache type
    paged_cache_type: PagedCacheType,
}

impl Default for MistralRsForServerBuilder {
    /// Creates a new builder with default configuration.
    fn default() -> Self {
        Self {
            device: defaults::DEVICE,
            seed: defaults::SEED,
            log: defaults::LOG,
            truncate_sequence: defaults::TRUNCATE_SEQUENCE,
            model: defaults::MODEL,
            models: Vec::new(),
            default_model_id: None,
            max_seqs: defaults::MAX_SEQS,
            no_kv_cache: defaults::NO_KV_CACHE,
            chat_template: defaults::CHAT_TEMPLATE,
            jinja_explicit: defaults::JINJA_EXPLICIT,
            token_source: defaults::TOKEN_SOURCE,
            interactive_mode: defaults::INTERACTIVE_MODE,
            prefix_cache_n: defaults::PREFIX_CACHE_N,
            num_device_layers: defaults::NUM_DEVICE_LAYERS,
            in_situ_quant: defaults::IN_SITU_QUANT,
            paged_attn_gpu_mem: defaults::PAGED_ATTN_GPU_MEM,
            paged_attn_gpu_mem_usage: defaults::PAGED_ATTN_GPU_MEM_USAGE,
            paged_ctxt_len: defaults::PAGED_CTXT_LEN,
            paged_attn_block_size: defaults::PAGED_ATTN_BLOCK_SIZE,
            paged_attn: defaults::PAGED_ATTN,
            cpu: defaults::CPU,
            enable_search: defaults::ENABLE_SEARCH,
            search_bert_model: defaults::SEARCH_BERT_MODEL,
            search_callback: defaults::SEARCH_CALLBACK,
            mcp_client_config: None,
            paged_cache_type: defaults::PAGED_CACHE_TYPE,
        }
    }
}

impl MistralRsForServerBuilder {
    /// Creates a new `MistralRsForServerBuilder` with default settings.
    ///
    /// This is equivalent to calling `Default::default()`.
    ///
    /// ### Examples
    ///
    /// ```ignore
    /// use mistralrs_server_core::mistralrs_for_server_builder::MistralRsForServerBuilder;
    ///
    /// let builder = MistralRsForServerBuilder::new();
    /// ```
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the Candle device to use for model execution.
    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

    /// Sets the random seed for deterministic model behavior.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Sets the random seed if provided.
    pub fn with_seed_optional(mut self, seed: Option<u64>) -> Self {
        if let Some(seed) = seed {
            self = self.with_seed(seed);
        }
        self
    }

    /// Sets the logging configuration.
    pub fn with_log(mut self, log: String) -> Self {
        self.log = Some(log);
        self
    }

    /// Sets the logging configuration if provided.
    pub fn with_log_optional(mut self, log: Option<String>) -> Self {
        if let Some(log) = log {
            self = self.with_log(log);
        }
        self
    }

    /// Sets whether to truncate sequences that exceed the maximum model length.
    pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
        self.truncate_sequence = truncate_sequence;
        self
    }

    /// Sets the model to be used.
    pub fn with_model(mut self, model: ModelSelected) -> Self {
        self.model = Some(model);
        self
    }

    /// Add a model to the multi-model configuration.
    pub fn with_model_config(mut self, model_config: ModelConfig) -> Self {
        self.models.push(model_config);
        self
    }

    /// Add multiple models to the multi-model configuration.
    pub fn with_model_configs(mut self, model_configs: Vec<ModelConfig>) -> Self {
        self.models.extend(model_configs);
        self
    }

    /// Set the default model ID to use when none is specified in requests.
    pub fn with_default_model_id(mut self, default_model_id: String) -> Self {
        self.default_model_id = Some(default_model_id);
        self
    }

    /// Add a model configuration.
    pub fn add_model_config(mut self, config: ModelConfig) -> Self {
        self.models.push(config);
        self
    }

    /// Add a model with just an ID and `ModelSelected` (convenience method).
    pub fn add_model(mut self, model_id: String, model: ModelSelected) -> Self {
        self.models.push(ModelConfig::new(model_id, model));
        self
    }

    /// Sets the maximum number of concurrent sequences.
    pub fn with_max_seqs(mut self, max_seqs: usize) -> Self {
        self.max_seqs = max_seqs;
        self
    }

    /// Sets whether to disable the key-value cache.
    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = no_kv_cache;
        self
    }

    /// Sets the chat template configuration.
    pub fn with_chat_template(mut self, chat_template: String) -> Self {
        self.chat_template = Some(chat_template);
        self
    }

    /// Sets the chat template configuration if provided.
    pub fn with_chat_template_optional(mut self, chat_template: Option<String>) -> Self {
        if let Some(chat_template) = chat_template {
            self = self.with_chat_template(chat_template);
        }
        self
    }

    /// Sets an explicit JINJA chat template file.
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    /// Sets an explicit JINJA chat template file if provided.
    pub fn with_jinja_explicit_optional(mut self, jinja_explicit: Option<String>) -> Self {
        if let Some(jinja_explicit) = jinja_explicit {
            self = self.with_jinja_explicit(jinja_explicit);
        }
        self
    }

    /// Sets the token source for authentication.
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Sets whether to run in interactive mode.
    pub fn with_interactive_mode(mut self, interactive_mode: bool) -> Self {
        self.interactive_mode = interactive_mode;
        self
    }

    /// Sets the number of prefix caches to hold on the device.
    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
        self.prefix_cache_n = prefix_cache_n;
        self
    }

    /// Sets the device layer mapping.
    pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
        self.num_device_layers = Some(num_device_layers);
        self
    }

    /// Sets the device layer mapping if provided.
    pub fn with_num_device_layers_optional(
        mut self,
        num_device_layers: Option<Vec<String>>,
    ) -> Self {
        if let Some(num_device_layers) = num_device_layers {
            self = self.with_num_device_layers(num_device_layers);
        }
        self
    }

    /// Sets the in-situ quantization method.
    pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
        self.in_situ_quant = Some(in_situ_quant);
        self
    }

    /// Sets the in-situ quantization method if provided.
    pub fn with_in_situ_quant_optional(mut self, in_situ_quant: Option<String>) -> Self {
        if let Some(in_situ_quant) = in_situ_quant {
            self = self.with_in_situ_quant(in_situ_quant);
        }
        self
    }

    /// Sets PagedAttention.
    ///
    /// Unlike other `with_PROP` or `with_PROP_optional` methods, this method
    /// sets the value to whatever `Option<bool>` is passed in, since `None`,
    /// `Some(true)`, and `Some(false)` each have different implications:
    ///
    /// - `None`: default behavior for the target device (e.g. enabled for CUDA, disabled for Metal)
    /// - `Some(true)`: enable (if supported by the target device)
    /// - `Some(false)`: disable
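    ///
    /// A sketch of the three cases (assumes a `mut builder` is in scope):
    /// ```ignore
    /// builder = builder.set_paged_attn(None);        // keep the device default
    /// builder = builder.set_paged_attn(Some(true));  // force-enable (if supported)
    /// builder = builder.set_paged_attn(Some(false)); // force-disable
    /// ```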
    pub fn set_paged_attn(mut self, paged_attn: Option<bool>) -> Self {
        self.paged_attn = paged_attn;
        self
    }

    /// Sets the GPU memory allocation for PagedAttention KV cache.
    pub fn with_paged_attn_gpu_mem(mut self, paged_attn_gpu_mem: usize) -> Self {
        self.paged_attn_gpu_mem = Some(paged_attn_gpu_mem);
        self
    }

    /// Sets the GPU memory allocation for PagedAttention KV cache if provided.
    pub fn with_paged_attn_gpu_mem_optional(mut self, paged_attn_gpu_mem: Option<usize>) -> Self {
        if let Some(paged_attn_gpu_mem) = paged_attn_gpu_mem {
            self = self.with_paged_attn_gpu_mem(paged_attn_gpu_mem);
        }
        self
    }

    /// Sets the percentage of GPU memory to utilize for PagedAttention.
    pub fn with_paged_attn_gpu_mem_usage(mut self, paged_attn_gpu_mem_usage: f32) -> Self {
        self.paged_attn_gpu_mem_usage = Some(paged_attn_gpu_mem_usage);
        self
    }

    /// Sets the percentage of GPU memory to utilize for PagedAttention if provided.
    pub fn with_paged_attn_gpu_mem_usage_optional(
        mut self,
        paged_attn_gpu_mem_usage: Option<f32>,
    ) -> Self {
        if let Some(paged_attn_gpu_mem_usage) = paged_attn_gpu_mem_usage {
            self = self.with_paged_attn_gpu_mem_usage(paged_attn_gpu_mem_usage);
        }
        self
    }

    /// Sets the total context length for KV cache allocation.
    pub fn with_paged_ctxt_len(mut self, paged_ctxt_len: usize) -> Self {
        self.paged_ctxt_len = Some(paged_ctxt_len);
        self
    }

    /// Sets the total context length for KV cache allocation if provided.
    pub fn with_paged_ctxt_len_optional(mut self, paged_ctxt_len: Option<usize>) -> Self {
        if let Some(paged_ctxt_len) = paged_ctxt_len {
            self = self.with_paged_ctxt_len(paged_ctxt_len);
        }
        self
    }

    /// Sets the block size for PagedAttention.
    pub fn with_paged_attn_block_size(mut self, paged_attn_block_size: usize) -> Self {
        self.paged_attn_block_size = Some(paged_attn_block_size);
        self
    }

    /// Sets the PagedAttention KV cache type.
    pub fn with_paged_attn_cache_type(mut self, cache_type: PagedCacheType) -> Self {
        self.paged_cache_type = cache_type;
        self
    }

    /// Sets the block size for PagedAttention if provided.
    pub fn with_paged_attn_block_size_optional(
        mut self,
        paged_attn_block_size: Option<usize>,
    ) -> Self {
        if let Some(paged_attn_block_size) = paged_attn_block_size {
            self = self.with_paged_attn_block_size(paged_attn_block_size);
        }
        self
    }

    /// Sets whether to force CPU-only execution.
    pub fn with_cpu(mut self, cpu: bool) -> Self {
        self.cpu = cpu;
        self
    }

    /// Sets whether to enable web search functionality.
    pub fn with_enable_search(mut self, enable_search: bool) -> Self {
        self.enable_search = enable_search;
        self
    }

    /// Sets the BERT model for web search assistance.
    pub fn with_search_bert_model(mut self, search_bert_model: String) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

    /// Override the search function used when `web_search_options` is enabled.
    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(callback);
        self
    }

    /// Sets the MCP client configuration.
    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(mcp_config);
        self
    }

    /// Sets the MCP client configuration if provided.
    pub fn with_mcp_config_optional(mut self, mcp_config: Option<McpClientConfig>) -> Self {
        if let Some(mcp_config) = mcp_config {
            self = self.with_mcp_config(mcp_config);
        }
        self
    }

    /// Builds the configured mistral.rs instance.
    ///
    /// ### Examples
    ///
    /// ```ignore
    /// use mistralrs_server_core::mistralrs_for_server_builder::MistralRsForServerBuilder;
    ///
    /// let shared_mistralrs = MistralRsForServerBuilder::new()
    ///     .with_model(model)
    ///     .with_in_situ_quant("8".to_string())
    ///     .set_paged_attn(Some(true))
    ///     .build()
    ///     .await?;
    /// ```
    pub async fn build(self) -> Result<SharedMistralRsState> {
        // Determine if we're in single-model or multi-model mode
        if !self.models.is_empty() {
            self.build_multi_model().await
        } else {
            self.build_single_model().await
        }
    }

    /// Build a single-model instance (legacy mode).
    async fn build_single_model(mut self) -> Result<SharedMistralRsState> {
        let model = self.model.context("Model was None")?;

        let tgt_non_granular_index = get_tgt_non_granular_index(&model);
        let dtype = get_model_dtype(&model)?;
        let auto_device_map_params = get_auto_device_map_params(&model)?;

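        // X-LoRA models with a target non-granular index only support a single running sequence.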
        if tgt_non_granular_index.is_some() {
            self.max_seqs = 1;
        }

        let max_seq_len = auto_device_map_params.max_seq_len();

        let device = if let Some(device) = self.device {
            device
        } else {
            init_device(self.cpu, self.seed)?
        };

        let mapper = init_mapper(&self.num_device_layers, &auto_device_map_params);
        let paged_attn = configure_paged_attn(&device, self.paged_attn);

        let cache_config = init_cache_config(
            self.paged_attn_block_size,
            self.paged_attn_gpu_mem,
            self.paged_attn_gpu_mem_usage,
            self.paged_ctxt_len,
            self.paged_cache_type,
            !paged_attn,
            max_seq_len,
        )?;

        // Configure this last to prevent arg moves
        let loader: Box<dyn Loader> = LoaderBuilder::new(model)
            .with_no_kv_cache(self.no_kv_cache)
            .with_chat_template(self.chat_template)
            .with_jinja_explicit(self.jinja_explicit)
            .build()?;

        mistralrs_instance_info(&*loader);

        let isq = self
            .in_situ_quant
            .as_ref()
            .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

        let pipeline: LoadedPipeline = loader.load_model_from_hf(
            None,
            self.token_source,
            &dtype,
            &device,
            false,
            mapper,
            isq,
            cache_config,
        )?;
        info!("Model loaded.");

        let scheduler_config = init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;

        let bert_model = get_bert_model(self.enable_search, self.search_bert_model);

        let mut builder = MistralRsBuilder::new(
            pipeline,
            scheduler_config,
            !self.interactive_mode,
            bert_model,
        )
        .with_opt_log(self.log)
        .with_truncate_sequence(self.truncate_sequence)
        .with_no_kv_cache(self.no_kv_cache)
        .with_prefix_cache_n(self.prefix_cache_n);

        // Add MCP client configuration if provided
        if let Some(mcp_config) = self.mcp_client_config {
            builder = builder.with_mcp_client(mcp_config);
        }

        let mistralrs = builder.build().await;

        Ok(mistralrs)
    }

    /// Build a multi-model instance.
    pub async fn build_multi_model(mut self) -> Result<SharedMistralRsState> {
        if self.models.is_empty() {
            anyhow::bail!("No models configured for multi-model mode");
        }

        // Use the first model as the base configuration
        let first_model = &self.models[0];
        let model = first_model.model.clone();

        let tgt_non_granular_index = get_tgt_non_granular_index(&model);
        let dtype = get_model_dtype(&model)?;
        let auto_device_map_params = get_auto_device_map_params(&model)?;

        if tgt_non_granular_index.is_some() {
            self.max_seqs = 1;
        }

        let max_seq_len = auto_device_map_params.max_seq_len();

        let device = if let Some(device) = self.device {
            device
        } else {
            init_device(self.cpu, self.seed)?
        };

        // Create the first model's pipeline
        let loader: Box<dyn Loader> = LoaderBuilder::new(model)
            .with_no_kv_cache(self.no_kv_cache)
            .with_chat_template(
                first_model
                    .chat_template
                    .clone()
                    .or(self.chat_template.clone()),
            )
            .with_jinja_explicit(
                first_model
                    .jinja_explicit
                    .clone()
                    .or(self.jinja_explicit.clone()),
            )
            .build()?;

        mistralrs_instance_info(&*loader);

        let mapper = init_mapper(
            &first_model
                .num_device_layers
                .clone()
                .or(self.num_device_layers.clone()),
            &auto_device_map_params,
        );
        let paged_attn = configure_paged_attn(&device, self.paged_attn);

        let cache_config = init_cache_config(
            self.paged_attn_block_size,
            self.paged_attn_gpu_mem,
            self.paged_attn_gpu_mem_usage,
            self.paged_ctxt_len,
            self.paged_cache_type,
            !paged_attn,
            max_seq_len,
        )?;

        let isq = first_model
            .in_situ_quant
            .as_ref()
            .or(self.in_situ_quant.as_ref())
            .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

        let mut pipeline_names = Vec::new();

        let pipeline: LoadedPipeline = loader.load_model_from_hf(
            None,
            self.token_source.clone(),
            &dtype,
            &device,
            false,
            mapper,
            isq,
            cache_config,
        )?;
        let first_pipeline_name = pipeline.lock().await.name();
        info!(
            "First model loaded: `{first_pipeline_name}` (from config key: {})",
            first_model.model_id
        );
        pipeline_names.push(first_pipeline_name);

        let scheduler_config = init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;
        let bert_model = get_bert_model(self.enable_search, self.search_bert_model);

        // Create the first MistralRs instance with the first model
        let mut builder = MistralRsBuilder::new(
            pipeline,
            scheduler_config.clone(),
            !self.interactive_mode,
            bert_model.clone(),
        )
        .with_opt_log(self.log.clone())
        .with_truncate_sequence(self.truncate_sequence)
        .with_no_kv_cache(self.no_kv_cache)
        .with_prefix_cache_n(self.prefix_cache_n);

        // Add MCP client configuration if provided
        if let Some(mcp_config) = self.mcp_client_config.clone() {
            builder = builder.with_mcp_client(mcp_config);
        }

        let mistralrs = builder.build().await;

        // Load additional models
        for model_config in self.models.iter().skip(1) {
            info!(
                "Loading additional model from config key: {}",
                model_config.model_id
            );

            let model = model_config.model.clone();
            let dtype = get_model_dtype(&model)?;
            let auto_device_map_params = get_auto_device_map_params(&model)?;

            let loader: Box<dyn Loader> = LoaderBuilder::new(model)
                .with_no_kv_cache(self.no_kv_cache)
                .with_chat_template(
                    model_config
                        .chat_template
                        .clone()
                        .or(self.chat_template.clone()),
                )
                .with_jinja_explicit(
                    model_config
                        .jinja_explicit
                        .clone()
                        .or(self.jinja_explicit.clone()),
                )
                .build()?;

            let mapper = init_mapper(
                &model_config
                    .num_device_layers
                    .clone()
                    .or(self.num_device_layers.clone()),
                &auto_device_map_params,
            );

            let isq = model_config
                .in_situ_quant
                .as_ref()
                .or(self.in_situ_quant.as_ref())
                .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

            let pipeline: LoadedPipeline = loader.load_model_from_hf(
                None,
                self.token_source.clone(),
                &dtype,
                &device,
                false,
                mapper,
                isq,
                cache_config,
            )?;

            // Use the pipeline's name() as the model ID
            let pipeline_name = pipeline.lock().await.name();

            // Check for model ID conflicts
            if pipeline_names.contains(&pipeline_name) {
                anyhow::bail!(
                    "Model ID conflict: '{}' is already registered. The model from config key '{}' resolves to the same pipeline identifier as a previously loaded model.",
                    pipeline_name,
                    model_config.model_id
                );
            }

            // Add the model to the MistralRs instance
            let engine_config = mistralrs_core::EngineConfig {
                truncate_sequence: self.truncate_sequence,
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                prefix_cache_n: self.prefix_cache_n,
                disable_eos_stop: false,
                throughput_logging_enabled: !self.interactive_mode,
                search_embedding_model: bert_model.clone(),
                search_callback: self.search_callback.clone(),
                tool_callbacks: HashMap::new(),
                tool_callbacks_with_tools: HashMap::new(),
            };

            let mut add_model_config = mistralrs_core::AddModelConfig::new(engine_config);
            if let Some(mcp_config) = self.mcp_client_config.clone() {
                add_model_config = add_model_config.with_mcp_config(mcp_config);
            }

            mistralrs
                .add_model(
                    pipeline_name.clone(),
                    pipeline,
                    scheduler_config.clone(),
                    add_model_config,
                )
                .await
                .map_err(|e| anyhow::anyhow!("Failed to add model {}: {}", pipeline_name, e))?;

            info!(
                "Model `{pipeline_name}` registered successfully (from config key: {})",
                model_config.model_id
            );
            pipeline_names.push(pipeline_name);
        }

        // Set the default model if specified
        if let Some(ref default_model_id) = self.default_model_id {
            mistralrs
                .set_default_model_id(default_model_id)
                .map_err(|e| anyhow::anyhow!("Failed to set default model: {}", e))?;
        }

        // Log all models loaded
        info!("All models loaded: `{}`", pipeline_names.join("`, `"));

        // Log default model
        if let Some(ref default_id) = self.default_model_id {
            info!("Default model: {}", default_id);
        } else {
            info!(
                "Default model: {} (first model, from config key: {})",
                pipeline_names[0], self.models[0].model_id
            );
        }
        Ok(mistralrs)
    }
}

// TODO: replace with best device?
/// Initializes the device to be used for computation, optionally forcing CPU usage and setting a seed.
fn init_device(force_cpu: bool, seed: Option<u64>) -> Result<candle_core::Device> {
    #[cfg(feature = "metal")]
    let device = if force_cpu {
        Device::Cpu
    } else {
        Device::new_metal(0)?
    };
    #[cfg(not(feature = "metal"))]
    #[allow(clippy::if_same_then_else)]
    let device = if force_cpu {
        Device::Cpu
    } else if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = seed {
        device.set_seed(seed)?;
    }

    Ok(device)
}

/// Initializes the device mapping configuration for distributing model layers.
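///
/// A single all-numeric entry such as `["26"]` maps 26 layers onto device ordinal 0;
/// otherwise each entry must follow `ORD:NUM` (e.g. `["0:16", "1:16"]`).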
fn init_mapper(
    num_device_layers: &Option<Vec<String>>,
    auto_device_map_params: &AutoDeviceMapParams,
) -> DeviceMapSetting {
    // Parse device mapper
    if let Some(device_layers) = num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params.clone())
    }
}

/// Logs hardware feature information and the model's sampling strategy and kind.
fn mistralrs_instance_info(loader: &dyn Loader) {
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );

    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    info!("Model kind is: {}", loader.get_kind());
}

/// Determines whether paged attention should be enabled based on device type and preferences.
fn configure_paged_attn(device: &Device, paged_attn: Option<bool>) -> bool {
    if device.is_cpu() {
        if paged_attn == Some(true) {
            warn!("Paged attention is not supported on CPU.");
        }

        defaults::PAGED_ATTN_CPU
    } else if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        paged_attn.unwrap_or(defaults::PAGED_ATTN_CUDA)
    } else if device.is_metal() {
        paged_attn.unwrap_or(defaults::PAGED_ATTN_METAL)
    } else {
        false
    }
}

/// Initializes the cache configuration for paged attention based on provided parameters.
fn init_cache_config(
    paged_attn_block_size: Option<usize>,
    paged_attn_gpu_mem: Option<usize>,
    paged_attn_gpu_mem_usage: Option<f32>,
    paged_ctxt_len: Option<usize>,
    cache_type: PagedCacheType,
    no_paged_attn: bool,
    max_seq_len: usize,
) -> Result<Option<PagedAttentionConfig>> {
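    // A config is only built when PagedAttention is supported on this platform and
    // has not been disabled; every unmatched combination falls through to `None`.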
    match (
        paged_attn_block_size,
        paged_attn_gpu_mem,
        paged_attn_gpu_mem_usage,
        paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::ContextSize(max_seq_len),
            cache_type,
        )?)),
        (block_size, None, None, Some(ctxt), true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::ContextSize(ctxt),
            cache_type,
        )?)),
        (block_size, None, Some(f), None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::Utilization(f),
            cache_type,
        )?)),
        (block_size, Some(m), None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::MbAmount(m),
            cache_type,
        )?)),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified; defaulting to the usage value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                MemoryGpuConfig::Utilization(f),
                cache_type,
            )?))
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified; defaulting to the context length value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                MemoryGpuConfig::ContextSize(ctxt),
                cache_type,
            )?))
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified; defaulting to the usage value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                MemoryGpuConfig::Utilization(f),
                cache_type,
            )?))
        }
        (_, _, _, _, _, _) => Ok(None),
    }
}

/// Initializes the scheduler configuration based on cache settings and pipeline metadata.
async fn init_scheduler_config(
    cache_config: &Option<PagedAttentionConfig>,
    pipeline: &LoadedPipeline,
    args_max_seqs: usize,
) -> SchedulerConfig {
    if cache_config.is_some() {
        // Handle case where we may have device mapping
        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: args_max_seqs,
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
        }
    }
}

/// Configures PagedAttention based on two flags.
///
/// This function resolves the tri-state PagedAttention configuration from
/// the mutually exclusive `paged_attn` and `no_paged_attn` flags.
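///
/// A quick sketch of the mapping:
/// ```ignore
/// assert_eq!(configure_paged_attn_from_flags(false, false)?, None);       // device default
/// assert_eq!(configure_paged_attn_from_flags(true, false)?, Some(true));  // force-enable
/// assert_eq!(configure_paged_attn_from_flags(false, true)?, Some(false)); // force-disable
/// ```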
pub fn configure_paged_attn_from_flags(
    paged_attn: bool,
    no_paged_attn: bool,
) -> Result<Option<bool>> {
    match (paged_attn, no_paged_attn) {
        (true, true) => {
            anyhow::bail!("Error: `--paged-attn` and `--no-paged-attn` cannot be used together.");
        }
        (true, false) => Ok(Some(true)),
        (false, true) => Ok(Some(false)),
        (false, false) => Ok(None),
    }
}

/// Creates a BERT embedding model configuration for search functionality.
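///
/// Returns `None` when search is disabled; otherwise the custom Hugging Face model ID
/// is used when given, falling back to the default embedding model.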
pub fn get_bert_model(
    enable_search: bool,
    search_bert_model: Option<String>,
) -> Option<BertEmbeddingModel> {
    if enable_search {
        Some(
            search_bert_model
                .map(BertEmbeddingModel::Custom)
                .unwrap_or_default(),
        )
    } else {
        None
    }
}