mistralrs_server_core/mistralrs_for_server_builder.rs

//! ## mistral.rs instance for server builder.

use std::{num::NonZeroUsize, sync::Arc};

use anyhow::{Context, Result};
use candle_core::Device;
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, paged_attn_supported,
    parse_isq_value, AutoDeviceMapParams, BertEmbeddingModel, DefaultSchedulerMethod,
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, Loader, LoaderBuilder,
    McpClientConfig, MemoryGpuConfig, MistralRsBuilder, ModelSelected, PagedAttentionConfig,
    PagedCacheType, SchedulerConfig, SearchCallback, TokenSource,
};
use tracing::{info, warn};

use crate::types::{LoadedPipeline, SharedMistralRsState};
use std::collections::HashMap;

/// Configuration for a single model in a multi-model setup
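///
/// ### Examples
///
/// A minimal sketch; `model` here stands in for a pre-built [`ModelSelected`] value,
/// and the template path is illustrative:
/// ```ignore
/// use mistralrs_server_core::mistralrs_for_server_builder::ModelConfig;
///
/// let config = ModelConfig::new("my-model".to_string(), model)
///     .with_chat_template("chat_templates/llama3.json".to_string())
///     .with_in_situ_quant("8".to_string());
/// ```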
#[derive(Clone, serde::Deserialize)]
pub struct ModelConfig {
    /// Model identifier (used in API requests)
    pub model_id: String,
    /// Model selector
    pub model: ModelSelected,
    /// Model-specific chat template
    pub chat_template: Option<String>,
    /// Model-specific JINJA template
    pub jinja_explicit: Option<String>,
    /// Model-specific device layers
    pub num_device_layers: Option<Vec<String>>,
    /// Model-specific in-situ quantization
    pub in_situ_quant: Option<String>,
}

impl ModelConfig {
    pub fn new(model_id: String, model: ModelSelected) -> Self {
        Self {
            model_id,
            model,
            chat_template: None,
            jinja_explicit: None,
            num_device_layers: None,
            in_situ_quant: None,
        }
    }

    pub fn with_chat_template(mut self, chat_template: String) -> Self {
        self.chat_template = Some(chat_template);
        self
    }

    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
        self.num_device_layers = Some(num_device_layers);
        self
    }

    pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
        self.in_situ_quant = Some(in_situ_quant);
        self
    }
}

pub mod defaults {
    //! Provides the default values used for the mistral.rs instance for server.
    //! These defaults can be used for CLI argument fallbacks, config loading, or general initialization.

    use std::sync::Arc;

    use mistralrs_core::PagedCacheType;

    pub const DEVICE: Option<candle_core::Device> = None;
    pub const SEED: Option<u64> = None;
    pub const LOG: Option<String> = None;
    pub const TRUNCATE_SEQUENCE: bool = false;
    pub const MODEL: Option<mistralrs_core::ModelSelected> = None;
    pub const MAX_SEQS: usize = 16;
    pub const NO_KV_CACHE: bool = false;
    pub const CHAT_TEMPLATE: Option<String> = None;
    pub const JINJA_EXPLICIT: Option<String> = None;
    pub const INTERACTIVE_MODE: bool = false;
    pub const PREFIX_CACHE_N: usize = 16;
    pub const NUM_DEVICE_LAYERS: Option<Vec<String>> = None;
    pub const IN_SITU_QUANT: Option<String> = None;
    pub const PAGED_ATTN_GPU_MEM: Option<usize> = None;
    pub const PAGED_ATTN_GPU_MEM_USAGE: Option<f32> = None;
    pub const PAGED_CTXT_LEN: Option<usize> = None;
    pub const PAGED_ATTN_BLOCK_SIZE: Option<usize> = None;
    pub const PAGED_ATTN: Option<bool> = None;
    pub const PAGED_ATTN_CPU: bool = false;
    pub const PAGED_ATTN_CUDA: bool = true;
    pub const PAGED_ATTN_METAL: bool = false;
    pub const PROMPT_CHUNKSIZE: Option<usize> = None;
    pub const CPU: bool = false;
    pub const ENABLE_SEARCH: bool = false;
    pub const SEARCH_BERT_MODEL: Option<String> = None;
    pub const TOKEN_SOURCE: mistralrs_core::TokenSource = mistralrs_core::TokenSource::CacheToken;
    pub const SEARCH_CALLBACK: Option<Arc<mistralrs_core::SearchCallback>> = None;
    pub const PAGED_CACHE_TYPE: PagedCacheType = PagedCacheType::Auto;
}

/// A builder for creating a mistral.rs instance with configured options for the mistral.rs server.
///
/// ### Examples
///
/// Basic usage:
/// ```ignore
/// use mistralrs_server_core::mistralrs_for_server_builder::MistralRsForServerBuilder;
///
/// let args = Args::parse();
///
/// let mistralrs = MistralRsForServerBuilder::new()
///        .with_truncate_sequence(args.truncate_sequence)
///        .with_model(args.model)
///        .with_max_seqs(args.max_seqs)
///        .with_no_kv_cache(args.no_kv_cache)
///        .with_token_source(args.token_source)
///        .with_interactive_mode(args.interactive_mode)
///        .with_prefix_cache_n(args.prefix_cache_n)
///        .set_paged_attn(args.paged_attn)
///        .with_cpu(args.cpu)
///        .with_enable_search(args.enable_search)
///        .with_seed_optional(args.seed)
///        .with_log_optional(args.log)
///        .with_chat_template_optional(args.chat_template)
///        .with_jinja_explicit_optional(args.jinja_explicit)
///        .with_num_device_layers_optional(args.num_device_layers)
///        .with_in_situ_quant_optional(args.in_situ_quant)
///        .with_paged_attn_gpu_mem_optional(args.paged_attn_gpu_mem)
///        .with_paged_attn_gpu_mem_usage_optional(args.paged_attn_gpu_mem_usage)
///        .with_paged_ctxt_len_optional(args.paged_ctxt_len)
///        .with_paged_attn_block_size_optional(args.paged_attn_block_size)
///        .with_prompt_chunksize_optional(args.prompt_chunksize)
///        .build()
///        .await?;
/// ```
pub struct MistralRsForServerBuilder {
    /// The Candle device to use for model execution (CPU, CUDA, Metal, etc.).
    device: Option<Device>,

    /// Integer seed to ensure reproducible random number generation.
    seed: Option<u64>,

    /// Log all responses and requests to this file
    log: Option<String>,

    /// If a sequence is larger than the maximum model length, truncate the number
    /// of tokens such that the sequence will fit at most the maximum length.
    /// If `max_tokens` is not specified in the request, space for 10 tokens will be reserved instead.
    truncate_sequence: bool,

    /// Model selector (for single-model mode, deprecated in favor of `models`)
    model: Option<ModelSelected>,

    /// Multiple model configurations (for multi-model mode)
    models: Vec<ModelConfig>,

    /// Default model ID to use when none is specified in requests
    default_model_id: Option<String>,

    /// Maximum running sequences at any time. If the `tgt_non_granular_index` flag is set for X-LoRA models, this will be set to 1.
    max_seqs: usize,

    /// Use no KV cache.
    no_kv_cache: bool,

    /// Chat template file containing a JINJA template with `messages`, `add_generation_prompt`, `bos_token`, `eos_token`, and `unk_token` as inputs.
    /// Used if automatic deserialization fails. If this ends with `.json` (i.e., it is a file), then that template is loaded.
    chat_template: Option<String>,

    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
    jinja_explicit: Option<String>,

    /// Source of the token for authentication.
    /// Can be in the formats: `literal:<value>`, `env:<value>`, `path:<value>`, `cache` to use a cached token, or `none` to use no token.
    /// Defaults to `cache`.
    token_source: TokenSource,

    /// Enter interactive mode instead of serving a chat server.
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device. Other caches are evicted to the CPU based on an LRU strategy.
    prefix_cache_n: usize,

    /// NOTE: This can be omitted to use automatic device mapping!
    /// Number of device layers to load and run on GPU(s). All others will be on the CPU.
    /// If one GPU is used, then this value should be an integer. Otherwise, it follows this pattern:
    /// ORD:NUM;... where ORD is a unique device ordinal and NUM is the number of layers for that device.
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for KV cache with PagedAttention in MBs.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    paged_attn_gpu_mem: Option<usize>,

    /// Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
    /// If this is not set and the device is CUDA, it will default to `0.9`.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold).
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    /// This is the default setting, and it defaults to the `max-seq-len` specified after the model type.
    paged_ctxt_len: Option<usize>,

    /// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    paged_attn_block_size: Option<usize>,

    /// Enables or disables PagedAttention. By default, PagedAttention will be enabled for CUDA and disabled for Metal (and is not supported for CPU). Use this to override the default behavior.
    paged_attn: Option<bool>,

    /// Number of tokens to batch the prompt step into. This can help with OOM errors when in the prompt step, but reduces performance.
    prompt_chunksize: Option<usize>,

    /// Use CPU only
    cpu: bool,

    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified below or the default.
    enable_search: bool,

    /// Specify a Hugging Face model ID for a BERT model to assist web searching. Defaults to Snowflake Arctic Embed L.
    search_bert_model: Option<String>,

    /// Optional override search callback
    search_callback: Option<Arc<SearchCallback>>,

    /// Optional MCP client configuration
    mcp_client_config: Option<McpClientConfig>,

    /// PagedAttention KV cache type
    paged_cache_type: PagedCacheType,
}

impl Default for MistralRsForServerBuilder {
    /// Creates a new builder with default configuration.
    fn default() -> Self {
        Self {
            device: defaults::DEVICE,
            seed: defaults::SEED,
            log: defaults::LOG,
            truncate_sequence: defaults::TRUNCATE_SEQUENCE,
            model: defaults::MODEL,
            models: Vec::new(),
            default_model_id: None,
            max_seqs: defaults::MAX_SEQS,
            no_kv_cache: defaults::NO_KV_CACHE,
            chat_template: defaults::CHAT_TEMPLATE,
            jinja_explicit: defaults::JINJA_EXPLICIT,
            token_source: defaults::TOKEN_SOURCE,
            interactive_mode: defaults::INTERACTIVE_MODE,
            prefix_cache_n: defaults::PREFIX_CACHE_N,
            num_device_layers: defaults::NUM_DEVICE_LAYERS,
            in_situ_quant: defaults::IN_SITU_QUANT,
            paged_attn_gpu_mem: defaults::PAGED_ATTN_GPU_MEM,
            paged_attn_gpu_mem_usage: defaults::PAGED_ATTN_GPU_MEM_USAGE,
            paged_ctxt_len: defaults::PAGED_CTXT_LEN,
            paged_attn_block_size: defaults::PAGED_ATTN_BLOCK_SIZE,
            paged_attn: defaults::PAGED_ATTN,
            prompt_chunksize: defaults::PROMPT_CHUNKSIZE,
            cpu: defaults::CPU,
            enable_search: defaults::ENABLE_SEARCH,
            search_bert_model: defaults::SEARCH_BERT_MODEL,
            search_callback: defaults::SEARCH_CALLBACK,
            mcp_client_config: None,
            paged_cache_type: defaults::PAGED_CACHE_TYPE,
        }
    }
}

impl MistralRsForServerBuilder {
    /// Creates a new `MistralRsForServerBuilder` with default settings.
    ///
    /// This is equivalent to calling `Default::default()`.
    ///
    /// ### Examples
    ///
    /// ```ignore
    /// use mistralrs_server_core::mistralrs_for_server_builder::MistralRsForServerBuilder;
    ///
    /// let builder = MistralRsForServerBuilder::new();
    /// ```
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the Candle device to use for model execution.
    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

    /// Sets the random seed for deterministic model behavior.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Sets the random seed if provided.
    pub fn with_seed_optional(mut self, seed: Option<u64>) -> Self {
        if let Some(seed) = seed {
            self = self.with_seed(seed);
        }
        self
    }

    /// Sets the logging configuration.
    pub fn with_log(mut self, log: String) -> Self {
        self.log = Some(log);
        self
    }

    /// Sets the logging configuration if provided.
    pub fn with_log_optional(mut self, log: Option<String>) -> Self {
        if let Some(log) = log {
            self = self.with_log(log);
        }
        self
    }

    /// Sets whether to truncate sequences that exceed the maximum model length.
    pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
        self.truncate_sequence = truncate_sequence;
        self
    }

    /// Sets the model to be used.
    pub fn with_model(mut self, model: ModelSelected) -> Self {
        self.model = Some(model);
        self
    }

    /// Add a model to the multi-model configuration.
    pub fn with_model_config(mut self, model_config: ModelConfig) -> Self {
        self.models.push(model_config);
        self
    }

    /// Add multiple models to the multi-model configuration.
    pub fn with_model_configs(mut self, model_configs: Vec<ModelConfig>) -> Self {
        self.models.extend(model_configs);
        self
    }

    /// Set the default model ID to use when none is specified in requests.
    pub fn with_default_model_id(mut self, default_model_id: String) -> Self {
        self.default_model_id = Some(default_model_id);
        self
    }

    /// Add a model configuration.
    pub fn add_model_config(mut self, config: ModelConfig) -> Self {
        self.models.push(config);
        self
    }

    /// Add a model with just an ID and `ModelSelected` (convenience method).
    pub fn add_model(mut self, model_id: String, model: ModelSelected) -> Self {
        self.models.push(ModelConfig::new(model_id, model));
        self
    }

    /// Sets the maximum number of concurrent sequences.
    pub fn with_max_seqs(mut self, max_seqs: usize) -> Self {
        self.max_seqs = max_seqs;
        self
    }

    /// Sets whether to disable the key-value cache.
    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = no_kv_cache;
        self
    }

    /// Sets the chat template configuration.
    pub fn with_chat_template(mut self, chat_template: String) -> Self {
        self.chat_template = Some(chat_template);
        self
    }

    /// Sets the chat template configuration if provided.
    pub fn with_chat_template_optional(mut self, chat_template: Option<String>) -> Self {
        if let Some(chat_template) = chat_template {
            self = self.with_chat_template(chat_template);
        }
        self
    }

    /// Sets an explicit JINJA chat template file.
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    /// Sets an explicit JINJA chat template file if provided.
    pub fn with_jinja_explicit_optional(mut self, jinja_explicit: Option<String>) -> Self {
        if let Some(jinja_explicit) = jinja_explicit {
            self = self.with_jinja_explicit(jinja_explicit);
        }
        self
    }

    /// Sets the token source for authentication.
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Sets whether to run in interactive mode.
    pub fn with_interactive_mode(mut self, interactive_mode: bool) -> Self {
        self.interactive_mode = interactive_mode;
        self
    }

    /// Sets the number of prefix caches to hold on the device.
    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
        self.prefix_cache_n = prefix_cache_n;
        self
    }

    /// Sets the device layer mapping.
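    ///
    /// ### Examples
    ///
    /// A sketch of the `ORD:NUM` format described on the builder's `num_device_layers`
    /// field (the ordinals and layer counts here are illustrative):
    /// ```ignore
    /// // 16 layers on device ordinal 0 and 16 on ordinal 1; remaining layers stay on the CPU.
    /// let builder = MistralRsForServerBuilder::new()
    ///     .with_num_device_layers(vec!["0:16".to_string(), "1:16".to_string()]);
    /// ```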
    pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
        self.num_device_layers = Some(num_device_layers);
        self
    }

    /// Sets the device layer mapping if provided.
    pub fn with_num_device_layers_optional(
        mut self,
        num_device_layers: Option<Vec<String>>,
    ) -> Self {
        if let Some(num_device_layers) = num_device_layers {
            self = self.with_num_device_layers(num_device_layers);
        }
        self
    }

    /// Sets the in-situ quantization method.
    pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
        self.in_situ_quant = Some(in_situ_quant);
        self
    }

    /// Sets the in-situ quantization method if provided.
    pub fn with_in_situ_quant_optional(mut self, in_situ_quant: Option<String>) -> Self {
        if let Some(in_situ_quant) = in_situ_quant {
            self = self.with_in_situ_quant(in_situ_quant);
        }
        self
    }

    /// Sets PagedAttention.
    ///
    /// Unlike the other `with_PROP` or `with_PROP_optional` methods, this method
    /// stores whatever `Option<bool>` is passed in, since `None`, `Some(true)`,
    /// and `Some(false)` have different implications:
    ///
    /// - `None`: default behavior for the target device (e.g. enabled for CUDA, disabled for Metal)
    /// - `Some(true)`: enable (if supported by the target device)
    /// - `Some(false)`: disable
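    ///
    /// ### Examples
    ///
    /// A brief sketch of the three states:
    /// ```ignore
    /// let b = MistralRsForServerBuilder::new().set_paged_attn(None);        // device default
    /// let b = MistralRsForServerBuilder::new().set_paged_attn(Some(true));  // enable if supported
    /// let b = MistralRsForServerBuilder::new().set_paged_attn(Some(false)); // disable
    /// ```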
    pub fn set_paged_attn(mut self, paged_attn: Option<bool>) -> Self {
        self.paged_attn = paged_attn;
        self
    }

    /// Sets the GPU memory allocation for PagedAttention KV cache.
    pub fn with_paged_attn_gpu_mem(mut self, paged_attn_gpu_mem: usize) -> Self {
        self.paged_attn_gpu_mem = Some(paged_attn_gpu_mem);
        self
    }

    /// Sets the GPU memory allocation for PagedAttention KV cache if provided.
    pub fn with_paged_attn_gpu_mem_optional(mut self, paged_attn_gpu_mem: Option<usize>) -> Self {
        if let Some(paged_attn_gpu_mem) = paged_attn_gpu_mem {
            self = self.with_paged_attn_gpu_mem(paged_attn_gpu_mem);
        }
        self
    }

    /// Sets the percentage of GPU memory to utilize for PagedAttention.
    pub fn with_paged_attn_gpu_mem_usage(mut self, paged_attn_gpu_mem_usage: f32) -> Self {
        self.paged_attn_gpu_mem_usage = Some(paged_attn_gpu_mem_usage);
        self
    }

    /// Sets the percentage of GPU memory to utilize for PagedAttention if provided.
    pub fn with_paged_attn_gpu_mem_usage_optional(
        mut self,
        paged_attn_gpu_mem_usage: Option<f32>,
    ) -> Self {
        if let Some(paged_attn_gpu_mem_usage) = paged_attn_gpu_mem_usage {
            self = self.with_paged_attn_gpu_mem_usage(paged_attn_gpu_mem_usage);
        }
        self
    }

    /// Sets the total context length for KV cache allocation.
    pub fn with_paged_ctxt_len(mut self, paged_ctxt_len: usize) -> Self {
        self.paged_ctxt_len = Some(paged_ctxt_len);
        self
    }

    /// Sets the total context length for KV cache allocation if provided.
    pub fn with_paged_ctxt_len_optional(mut self, paged_ctxt_len: Option<usize>) -> Self {
        if let Some(paged_ctxt_len) = paged_ctxt_len {
            self = self.with_paged_ctxt_len(paged_ctxt_len);
        }
        self
    }

    /// Sets the block size for PagedAttention.
    pub fn with_paged_attn_block_size(mut self, paged_attn_block_size: usize) -> Self {
        self.paged_attn_block_size = Some(paged_attn_block_size);
        self
    }

    /// Sets the PagedAttention KV cache type.
    pub fn with_paged_attn_cache_type(mut self, cache_type: PagedCacheType) -> Self {
        self.paged_cache_type = cache_type;
        self
    }

    /// Sets the block size for PagedAttention if provided.
    pub fn with_paged_attn_block_size_optional(
        mut self,
        paged_attn_block_size: Option<usize>,
    ) -> Self {
        if let Some(paged_attn_block_size) = paged_attn_block_size {
            self = self.with_paged_attn_block_size(paged_attn_block_size);
        }
        self
    }

    /// Sets the prompt chunking size.
    pub fn with_prompt_chunksize(mut self, prompt_chunksize: usize) -> Self {
        self.prompt_chunksize = Some(prompt_chunksize);
        self
    }

    /// Sets the prompt chunking size if provided.
    pub fn with_prompt_chunksize_optional(mut self, prompt_chunksize: Option<usize>) -> Self {
        if let Some(prompt_chunksize) = prompt_chunksize {
            self = self.with_prompt_chunksize(prompt_chunksize);
        }
        self
    }

    /// Sets whether to force CPU-only execution.
    pub fn with_cpu(mut self, cpu: bool) -> Self {
        self.cpu = cpu;
        self
    }

    /// Sets whether to enable web search functionality.
    pub fn with_enable_search(mut self, enable_search: bool) -> Self {
        self.enable_search = enable_search;
        self
    }

    /// Sets the BERT model for web search assistance.
    pub fn with_search_bert_model(mut self, search_bert_model: String) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

    /// Override the search function used when `web_search_options` is enabled.
    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(callback);
        self
    }

    /// Sets the MCP client configuration.
    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(mcp_config);
        self
    }

    /// Sets the MCP client configuration if provided.
    pub fn with_mcp_config_optional(mut self, mcp_config: Option<McpClientConfig>) -> Self {
        if let Some(mcp_config) = mcp_config {
            self = self.with_mcp_config(mcp_config);
        }
        self
    }

    /// Builds the configured mistral.rs instance.
    ///
    /// ### Examples
    ///
    /// ```ignore
    /// use mistralrs_server_core::mistralrs_for_server_builder::MistralRsForServerBuilder;
    ///
    /// let shared_mistralrs = MistralRsForServerBuilder::new()
    ///     .with_model(model)
    ///     .with_in_situ_quant("8".to_string())
    ///     .set_paged_attn(Some(true))
    ///     .build()
    ///     .await?;
    /// ```
    pub async fn build(self) -> Result<SharedMistralRsState> {
        // Determine if we're in single-model or multi-model mode
        if !self.models.is_empty() {
            self.build_multi_model().await
        } else {
            self.build_single_model().await
        }
    }

    /// Build a single-model instance (legacy mode).
    async fn build_single_model(mut self) -> Result<SharedMistralRsState> {
        let model = self.model.context("Model was None")?;

        let tgt_non_granular_index = get_tgt_non_granular_index(&model);
        let dtype = get_model_dtype(&model)?;
        let auto_device_map_params = get_auto_device_map_params(&model)?;

        if tgt_non_granular_index.is_some() {
            self.max_seqs = 1;
        }

        let prompt_chunksize = match self.prompt_chunksize {
            Some(0) => {
                anyhow::bail!("`prompt_chunksize` must be a strictly positive integer, got 0.")
            }
            Some(x) => Some(NonZeroUsize::new(x).unwrap()),
            None => None,
        };

        let max_seq_len = auto_device_map_params.max_seq_len();

        let device = if let Some(device) = self.device {
            device
        } else {
            init_device(self.cpu, self.seed)?
        };

        let mapper = init_mapper(&self.num_device_layers, &auto_device_map_params);
        let paged_attn = configure_paged_attn(&device, self.paged_attn);

        // Allocate 0.5 GB of CPU memory just as a placeholder.
        // Nothing happens here as we have no `swap_out`, see `_preempt_by_swap`.
        let cache_config = init_cache_config(
            self.paged_attn_block_size,
            self.paged_attn_gpu_mem,
            self.paged_attn_gpu_mem_usage,
            self.paged_ctxt_len,
            self.paged_cache_type,
            !paged_attn,
            max_seq_len,
        )?;

        // Configure this last to prevent arg moves
        let loader: Box<dyn Loader> = LoaderBuilder::new(model)
            .with_no_kv_cache(self.no_kv_cache)
            .with_chat_template(self.chat_template)
            .with_prompt_chunksize(prompt_chunksize)
            .with_jinja_explicit(self.jinja_explicit)
            .build()?;

        mistralrs_instance_info(&*loader);

        let isq = self
            .in_situ_quant
            .as_ref()
            .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

        let pipeline: LoadedPipeline = loader.load_model_from_hf(
            None,
            self.token_source,
            &dtype,
            &device,
            false,
            mapper,
            isq,
            cache_config,
        )?;
        info!("Model loaded.");

        let scheduler_config = init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;

        let bert_model = get_bert_model(self.enable_search, self.search_bert_model);

        let mut builder = MistralRsBuilder::new(
            pipeline,
            scheduler_config,
            !self.interactive_mode,
            bert_model,
        )
        .with_opt_log(self.log)
        .with_truncate_sequence(self.truncate_sequence)
        .with_no_kv_cache(self.no_kv_cache)
        .with_prefix_cache_n(self.prefix_cache_n);

        // Add MCP client configuration if provided
        if let Some(mcp_config) = self.mcp_client_config {
            builder = builder.with_mcp_client(mcp_config);
        }

        let mistralrs = builder.build().await;

        Ok(mistralrs)
    }

    /// Build a multi-model instance.
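    ///
    /// ### Examples
    ///
    /// A hedged sketch; `llama_model` and `mistral_model` stand in for pre-built
    /// [`ModelSelected`] values:
    /// ```ignore
    /// let mistralrs = MistralRsForServerBuilder::new()
    ///     .add_model("llama".to_string(), llama_model)
    ///     .add_model("mistral".to_string(), mistral_model)
    ///     .with_default_model_id("llama".to_string())
    ///     .build_multi_model()
    ///     .await?;
    /// ```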
    pub async fn build_multi_model(mut self) -> Result<SharedMistralRsState> {
        if self.models.is_empty() {
            anyhow::bail!("No models configured for multi-model mode");
        }

        // Use the first model as the base configuration
        let first_model = &self.models[0];
        let model = first_model.model.clone();

        let tgt_non_granular_index = get_tgt_non_granular_index(&model);
        let dtype = get_model_dtype(&model)?;
        let auto_device_map_params = get_auto_device_map_params(&model)?;

        if tgt_non_granular_index.is_some() {
            self.max_seqs = 1;
        }

        let prompt_chunksize = match self.prompt_chunksize {
            Some(0) => {
                anyhow::bail!("`prompt_chunksize` must be a strictly positive integer, got 0.")
            }
            Some(x) => Some(NonZeroUsize::new(x).unwrap()),
            None => None,
        };

        let max_seq_len = auto_device_map_params.max_seq_len();

        let device = if let Some(device) = self.device {
            device
        } else {
            init_device(self.cpu, self.seed)?
        };

        // Create the first model's pipeline
        let loader: Box<dyn Loader> = LoaderBuilder::new(model)
            .with_no_kv_cache(self.no_kv_cache)
            .with_chat_template(
                first_model
                    .chat_template
                    .clone()
                    .or(self.chat_template.clone()),
            )
            .with_prompt_chunksize(prompt_chunksize)
            .with_jinja_explicit(
                first_model
                    .jinja_explicit
                    .clone()
                    .or(self.jinja_explicit.clone()),
            )
            .build()?;

        mistralrs_instance_info(&*loader);

        let mapper = init_mapper(
            &first_model
                .num_device_layers
                .clone()
                .or(self.num_device_layers.clone()),
            &auto_device_map_params,
        );
        let paged_attn = configure_paged_attn(&device, self.paged_attn);

        let cache_config = init_cache_config(
            self.paged_attn_block_size,
            self.paged_attn_gpu_mem,
            self.paged_attn_gpu_mem_usage,
            self.paged_ctxt_len,
            self.paged_cache_type,
            !paged_attn,
            max_seq_len,
        )?;

        let isq = first_model
            .in_situ_quant
            .as_ref()
            .or(self.in_situ_quant.as_ref())
            .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

        let mut pipeline_names = Vec::new();

        let pipeline: LoadedPipeline = loader.load_model_from_hf(
            None,
            self.token_source.clone(),
            &dtype,
            &device,
            false,
            mapper,
            isq,
            cache_config,
        )?;
        let first_pipeline_name = pipeline.lock().await.name();
        info!(
            "First model loaded: `{first_pipeline_name}` (from config key: {})",
            first_model.model_id
        );
        pipeline_names.push(first_pipeline_name);

        let scheduler_config = init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;
        let bert_model = get_bert_model(self.enable_search, self.search_bert_model);

        // Create the first MistralRs instance with the first model
        let mut builder = MistralRsBuilder::new(
            pipeline,
            scheduler_config.clone(),
            !self.interactive_mode,
            bert_model.clone(),
        )
        .with_opt_log(self.log.clone())
        .with_truncate_sequence(self.truncate_sequence)
        .with_no_kv_cache(self.no_kv_cache)
        .with_prefix_cache_n(self.prefix_cache_n);

        // Add MCP client configuration if provided
        if let Some(mcp_config) = self.mcp_client_config.clone() {
            builder = builder.with_mcp_client(mcp_config);
        }

        let mistralrs = builder.build().await;

        // Load additional models
        for model_config in self.models.iter().skip(1) {
            info!(
                "Loading additional model from config key: {}",
                model_config.model_id
            );

            let model = model_config.model.clone();
            let dtype = get_model_dtype(&model)?;
            let auto_device_map_params = get_auto_device_map_params(&model)?;

            let loader: Box<dyn Loader> = LoaderBuilder::new(model)
                .with_no_kv_cache(self.no_kv_cache)
                .with_chat_template(
                    model_config
                        .chat_template
                        .clone()
                        .or(self.chat_template.clone()),
                )
                .with_prompt_chunksize(prompt_chunksize)
                .with_jinja_explicit(
                    model_config
                        .jinja_explicit
                        .clone()
                        .or(self.jinja_explicit.clone()),
                )
                .build()?;

            let mapper = init_mapper(
                &model_config
                    .num_device_layers
                    .clone()
                    .or(self.num_device_layers.clone()),
                &auto_device_map_params,
            );

            let isq = model_config
                .in_situ_quant
                .as_ref()
                .or(self.in_situ_quant.as_ref())
                .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

            let pipeline: LoadedPipeline = loader.load_model_from_hf(
                None,
                self.token_source.clone(),
                &dtype,
                &device,
                false,
                mapper,
                isq,
                cache_config,
            )?;

            // Use the pipeline's name() as the model ID
            let pipeline_name = pipeline.lock().await.name();

            // Check for model ID conflicts
            if pipeline_names.contains(&pipeline_name) {
                anyhow::bail!(
                    "Model ID conflict: '{}' is already registered. The model from config key '{}' resolves to the same pipeline identifier as a previously loaded model.",
                    pipeline_name,
                    model_config.model_id
                );
            }

            // Add the model to the MistralRs instance
            let engine_config = mistralrs_core::EngineConfig {
                truncate_sequence: self.truncate_sequence,
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                prefix_cache_n: self.prefix_cache_n,
                disable_eos_stop: false,
                throughput_logging_enabled: !self.interactive_mode,
                search_embedding_model: bert_model.clone(),
                search_callback: self.search_callback.clone(),
                tool_callbacks: HashMap::new(),
                tool_callbacks_with_tools: HashMap::new(),
            };

            let mut add_model_config = mistralrs_core::AddModelConfig::new(engine_config);
            if let Some(mcp_config) = self.mcp_client_config.clone() {
                add_model_config = add_model_config.with_mcp_config(mcp_config);
            }

            mistralrs
                .add_model(
                    pipeline_name.clone(),
                    pipeline,
                    scheduler_config.clone(),
                    add_model_config,
                )
                .await
                .map_err(|e| anyhow::anyhow!("Failed to add model {}: {}", pipeline_name, e))?;

            info!(
                "Model `{pipeline_name}` registered successfully (from config key: {})",
                model_config.model_id
            );
            pipeline_names.push(pipeline_name);
        }

        // Set the default model if specified
        if let Some(ref default_model_id) = self.default_model_id {
            mistralrs
                .set_default_model_id(default_model_id)
                .map_err(|e| anyhow::anyhow!("Failed to set default model: {}", e))?;
        }

        // Log all models loaded
        info!("All models loaded: `{}`", pipeline_names.join("`, `"));

        // Log default model
        if let Some(ref default_id) = self.default_model_id {
            info!("Default model: {}", default_id);
        } else {
            info!(
                "Default model: {} (first model, from config key: {})",
                pipeline_names[0], self.models[0].model_id
            );
        }
        Ok(mistralrs)
    }
}

// TODO: replace with best device?
/// Initializes the device to be used for computation, optionally forcing CPU usage and setting a seed.
fn init_device(force_cpu: bool, seed: Option<u64>) -> Result<candle_core::Device> {
    #[cfg(feature = "metal")]
    let device = if force_cpu {
        Device::Cpu
    } else {
        Device::new_metal(0)?
    };
    #[cfg(not(feature = "metal"))]
    #[allow(clippy::if_same_then_else)]
    let device = if force_cpu {
        Device::Cpu
    } else if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = seed {
        device.set_seed(seed)?;
    }

    Ok(device)
}

/// Initializes the device mapping configuration for distributing model layers.
fn init_mapper(
    num_device_layers: &Option<Vec<String>>,
    auto_device_map_params: &AutoDeviceMapParams,
) -> DeviceMapSetting {
    // Parse device mapper
    if let Some(device_layers) = num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params.clone())
    }
}

/// Logs hardware feature information and the model's sampling strategy and kind.
fn mistralrs_instance_info(loader: &dyn Loader) {
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );

    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    info!("Model kind is: {}", loader.get_kind());
}

/// Determines whether paged attention should be enabled based on device type and preferences.
fn configure_paged_attn(device: &Device, paged_attn: Option<bool>) -> bool {
    if device.is_cpu() {
        if paged_attn == Some(true) {
            warn!("Paged attention is not supported on CPU.");
        }

        defaults::PAGED_ATTN_CPU
    } else if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        paged_attn.unwrap_or(defaults::PAGED_ATTN_CUDA)
    } else if device.is_metal() {
        paged_attn.unwrap_or(defaults::PAGED_ATTN_METAL)
    } else {
        false
    }
}

/// Initializes the cache configuration for paged attention based on provided parameters.
fn init_cache_config(
    paged_attn_block_size: Option<usize>,
    paged_attn_gpu_mem: Option<usize>,
    paged_attn_gpu_mem_usage: Option<f32>,
    paged_ctxt_len: Option<usize>,
    cache_type: PagedCacheType,
    no_paged_attn: bool,
    max_seq_len: usize,
) -> Result<Option<PagedAttentionConfig>> {
    match (
        paged_attn_block_size,
        paged_attn_gpu_mem,
        paged_attn_gpu_mem_usage,
        paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
            cache_type,
        )?)),
        (block_size, None, None, Some(ctxt), true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
            cache_type,
        )?)),
        (block_size, None, Some(f), None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
            cache_type,
        )?)),
        (block_size, Some(m), None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
            cache_type,
        )?)),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified, defaulting to the usage value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
                cache_type,
            )?))
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified, defaulting to the context length value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
                cache_type,
            )?))
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified, defaulting to the usage value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
                cache_type,
            )?))
        }
        (_, _, _, _, _, _) => Ok(None),
    }
}

/// Initializes the scheduler configuration based on cache settings and pipeline metadata.
async fn init_scheduler_config(
    cache_config: &Option<PagedAttentionConfig>,
    pipeline: &LoadedPipeline,
    args_max_seqs: usize,
) -> SchedulerConfig {
    if cache_config.is_some() {
        // Handle case where we may have device mapping
        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: args_max_seqs,
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
        }
    }
}

/// Configures PagedAttention based on two flags.
///
/// This function resolves the tri-state PagedAttention configuration from
/// the mutually exclusive `paged_attn` and `no_paged_attn` flags.
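///
/// ### Examples
///
/// A brief sketch of the resolution (in a context where `?` works on `anyhow::Result`):
/// ```ignore
/// assert_eq!(configure_paged_attn_from_flags(true, false)?, Some(true));
/// assert_eq!(configure_paged_attn_from_flags(false, true)?, Some(false));
/// assert_eq!(configure_paged_attn_from_flags(false, false)?, None);
/// ```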
pub fn configure_paged_attn_from_flags(
    paged_attn: bool,
    no_paged_attn: bool,
) -> Result<Option<bool>> {
    match (paged_attn, no_paged_attn) {
        (true, true) => {
            anyhow::bail!("Error: `--paged-attn` and `--no-paged-attn` cannot be used together.");
        }
        (true, false) => Ok(Some(true)),
        (false, true) => Ok(Some(false)),
        (false, false) => Ok(None),
    }
}

/// Creates a BERT embedding model configuration for search functionality.
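///
/// ### Examples
///
/// A brief sketch:
/// ```ignore
/// // Search disabled: no embedding model is configured.
/// assert!(get_bert_model(false, None).is_none());
///
/// // Search enabled without an override: falls back to the default BERT model.
/// assert!(get_bert_model(true, None).is_some());
///
/// // Search enabled with a custom Hugging Face model ID (illustrative), wrapped
/// // in `BertEmbeddingModel::Custom`.
/// let custom = get_bert_model(true, Some("my-org/my-bert".to_string()));
/// ```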
pub fn get_bert_model(
    enable_search: bool,
    search_bert_model: Option<String>,
) -> Option<BertEmbeddingModel> {
    if enable_search {
        Some(
            search_bert_model
                .map(BertEmbeddingModel::Custom)
                .unwrap_or_default(),
        )
    } else {
        None
    }
}