mistralrs_server_core/mistralrs_for_server_builder.rs

//! ## Builder for the mistral.rs instance used by the server.

use std::{num::NonZeroUsize, sync::Arc};

use anyhow::{Context, Result};
use candle_core::Device;
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, paged_attn_supported,
    parse_isq_value, AutoDeviceMapParams, BertEmbeddingModel, DefaultSchedulerMethod,
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, Loader, LoaderBuilder,
    MemoryGpuConfig, MistralRsBuilder, ModelSelected, PagedAttentionConfig, SchedulerConfig,
    SearchCallback, TokenSource,
};
use tracing::info;

use crate::types::{LoadedPipeline, SharedMistralRsState};
pub mod defaults {
    //! Default values for the mistral.rs server instance.
    //! These defaults can be used for CLI argument fallbacks, config loading, or general initialization.
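    //!
    //! For example, a CLI fallback might look like this sketch (the `args` struct is hypothetical):
    //!
    //! ```ignore
    //! use mistralrs_server_core::mistralrs_for_server_builder::defaults;
    //!
    //! // Fall back to the crate default when the flag was not given on the command line.
    //! let max_seqs = args.max_seqs.unwrap_or(defaults::MAX_SEQS);
    //! ```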

    use std::sync::Arc;

    pub const DEVICE: Option<candle_core::Device> = None;
    pub const SEED: Option<u64> = None;
    pub const LOG: Option<String> = None;
    pub const TRUNCATE_SEQUENCE: bool = false;
    pub const MODEL: Option<mistralrs_core::ModelSelected> = None;
    pub const MAX_SEQS: usize = 16;
    pub const NO_KV_CACHE: bool = false;
    pub const CHAT_TEMPLATE: Option<String> = None;
    pub const JINJA_EXPLICIT: Option<String> = None;
    pub const INTERACTIVE_MODE: bool = false;
    pub const PREFIX_CACHE_N: usize = 16;
    pub const NUM_DEVICE_LAYERS: Option<Vec<String>> = None;
    pub const IN_SITU_QUANT: Option<String> = None;
    pub const PAGED_ATTN_GPU_MEM: Option<usize> = None;
    pub const PAGED_ATTN_GPU_MEM_USAGE: Option<f32> = None;
    pub const PAGED_CTXT_LEN: Option<usize> = None;
    pub const PAGED_ATTN_BLOCK_SIZE: Option<usize> = None;
    pub const NO_PAGED_ATTN: bool = false;
    pub const PAGED_ATTN: bool = false;
    pub const PROMPT_CHUNKSIZE: Option<usize> = None;
    pub const CPU: bool = false;
    pub const ENABLE_SEARCH: bool = false;
    pub const SEARCH_BERT_MODEL: Option<String> = None;
    pub const TOKEN_SOURCE: mistralrs_core::TokenSource = mistralrs_core::TokenSource::CacheToken;
    pub const SEARCH_CALLBACK: Option<Arc<mistralrs_core::SearchCallback>> = None;
}

/// A builder for creating a mistral.rs instance with configured options for the mistral.rs server.
///
/// ### Examples
///
/// Basic usage:
/// ```ignore
/// use mistralrs_server_core::mistralrs_for_server_builder::MistralRsForServerBuilder;
///
/// let args = Args::parse();
///
/// let mistralrs = MistralRsForServerBuilder::new()
///     .with_truncate_sequence(args.truncate_sequence)
///     .with_model(args.model)
///     .with_max_seqs(args.max_seqs)
///     .with_no_kv_cache(args.no_kv_cache)
///     .with_token_source(args.token_source)
///     .with_interactive_mode(args.interactive_mode)
///     .with_prefix_cache_n(args.prefix_cache_n)
///     .with_no_paged_attn(args.no_paged_attn)
///     .with_paged_attn(args.paged_attn)
///     .with_cpu(args.cpu)
///     .with_enable_search(args.enable_search)
///     .with_seed_optional(args.seed)
///     .with_log_optional(args.log)
///     .with_chat_template_optional(args.chat_template)
///     .with_jinja_explicit_optional(args.jinja_explicit)
///     .with_num_device_layers_optional(args.num_device_layers)
///     .with_in_situ_quant_optional(args.in_situ_quant)
///     .with_paged_attn_gpu_mem_optional(args.paged_attn_gpu_mem)
///     .with_paged_attn_gpu_mem_usage_optional(args.paged_attn_gpu_mem_usage)
///     .with_paged_ctxt_len_optional(args.paged_ctxt_len)
///     .with_paged_attn_block_size_optional(args.paged_attn_block_size)
///     .with_prompt_chunksize_optional(args.prompt_chunksize)
///     .build()
///     .await?;
/// ```
pub struct MistralRsForServerBuilder {
    /// The Candle device to use for model execution (CPU, CUDA, Metal, etc.).
    device: Option<Device>,

    /// Integer seed to ensure reproducible random number generation.
    seed: Option<u64>,

    /// Log all responses and requests to this file
    log: Option<String>,

    /// If a sequence is larger than the maximum model length, truncate the number
    /// of tokens such that the sequence will fit at most the maximum length.
    /// If `max_tokens` is not specified in the request, space for 10 tokens will be reserved instead.
    truncate_sequence: bool,

    /// Model selector
    model: Option<ModelSelected>,

    /// Maximum running sequences at any time. If the `tgt_non_granular_index` flag is set for X-LoRA models, this will be set to 1.
    max_seqs: usize,

    /// Use no KV cache.
    no_kv_cache: bool,

    /// JINJA chat template file with `messages`, `add_generation_prompt`, `bos_token`, `eos_token`, and `unk_token` as inputs.
    /// Used if the automatic deserialization fails. If this ends with `.json` (i.e., it is a file), then that template is loaded.
    chat_template: Option<String>,

    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
    jinja_explicit: Option<String>,

    /// Source of the token for authentication.
    /// Can be in the formats: `literal:<value>`, `env:<value>`, `path:<value>`, `cache` to use a cached token, or `none` to use no token.
    /// Defaults to `cache`.
    token_source: TokenSource,

    /// Enter interactive mode instead of serving a chat server.
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device. Other caches are evicted to the CPU based on an LRU strategy.
    prefix_cache_n: usize,

    /// NOTE: This can be omitted to use automatic device mapping!
    /// Number of device layers to load and run on GPU(s). All others will be on the CPU.
    /// If one GPU is used, then this value should be an integer. Otherwise, it follows the pattern
    /// ORD:NUM;... where ORD is a unique device ordinal and NUM is the number of layers for that device.
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for KV cache with PagedAttention in MBs.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    paged_attn_gpu_mem: Option<usize>,

    /// Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
    /// If this is not set and the device is CUDA, it will default to `0.9`.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold).
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    /// This is the default setting, and it defaults to the `max-seq-len` specified after the model type.
    paged_ctxt_len: Option<usize>,

    /// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention on CUDA. Because PagedAttention is already disabled on Metal, this is only applicable on CUDA.
    no_paged_attn: bool,

    /// Enable PagedAttention on Metal. Because PagedAttention is already enabled on CUDA, this is only applicable on Metal.
    paged_attn: bool,

    /// Number of tokens to batch the prompt step into. This can help with OOM errors when in the prompt step, but reduces performance.
    prompt_chunksize: Option<usize>,

    /// Use CPU only
    cpu: bool,

    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified below or the default.
    enable_search: bool,

    /// Specify a Hugging Face model ID for a BERT model to assist web searching. Defaults to Snowflake Arctic Embed L.
    search_bert_model: Option<String>,

    /// Optional override search callback
    search_callback: Option<Arc<SearchCallback>>,
}

impl Default for MistralRsForServerBuilder {
    /// Creates a new builder with default configuration.
    fn default() -> Self {
        Self {
            device: defaults::DEVICE,
            seed: defaults::SEED,
            log: defaults::LOG,
            truncate_sequence: defaults::TRUNCATE_SEQUENCE,
            model: defaults::MODEL,
            max_seqs: defaults::MAX_SEQS,
            no_kv_cache: defaults::NO_KV_CACHE,
            chat_template: defaults::CHAT_TEMPLATE,
            jinja_explicit: defaults::JINJA_EXPLICIT,
            token_source: defaults::TOKEN_SOURCE,
            interactive_mode: defaults::INTERACTIVE_MODE,
            prefix_cache_n: defaults::PREFIX_CACHE_N,
            num_device_layers: defaults::NUM_DEVICE_LAYERS,
            in_situ_quant: defaults::IN_SITU_QUANT,
            paged_attn_gpu_mem: defaults::PAGED_ATTN_GPU_MEM,
            paged_attn_gpu_mem_usage: defaults::PAGED_ATTN_GPU_MEM_USAGE,
            paged_ctxt_len: defaults::PAGED_CTXT_LEN,
            paged_attn_block_size: defaults::PAGED_ATTN_BLOCK_SIZE,
            no_paged_attn: defaults::NO_PAGED_ATTN,
            paged_attn: defaults::PAGED_ATTN,
            prompt_chunksize: defaults::PROMPT_CHUNKSIZE,
            cpu: defaults::CPU,
            enable_search: defaults::ENABLE_SEARCH,
            search_bert_model: defaults::SEARCH_BERT_MODEL,
            search_callback: defaults::SEARCH_CALLBACK,
        }
    }
}

impl MistralRsForServerBuilder {
    /// Creates a new `MistralRsForServerBuilder` with default settings.
    ///
    /// This is equivalent to calling `Default::default()`.
    ///
    /// ### Examples
    ///
    /// ```ignore
    /// use mistralrs_server_core::mistralrs_for_server_builder::MistralRsForServerBuilder;
    ///
    /// let builder = MistralRsForServerBuilder::new();
    /// ```
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the Candle device to use for model execution.
    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

    /// Sets the random seed for deterministic model behavior.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Sets the random seed if provided.
    pub fn with_seed_optional(mut self, seed: Option<u64>) -> Self {
        if let Some(seed) = seed {
            self = self.with_seed(seed);
        }
        self
    }

    /// Sets the logging configuration.
    pub fn with_log(mut self, log: String) -> Self {
        self.log = Some(log);
        self
    }

    /// Sets the logging configuration if provided.
    pub fn with_log_optional(mut self, log: Option<String>) -> Self {
        if let Some(log) = log {
            self = self.with_log(log);
        }
        self
    }

    /// Sets whether to truncate sequences that exceed the maximum model length.
    pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
        self.truncate_sequence = truncate_sequence;
        self
    }

    /// Sets the model to be used.
    pub fn with_model(mut self, model: ModelSelected) -> Self {
        self.model = Some(model);
        self
    }

    /// Sets the maximum number of concurrent sequences.
    pub fn with_max_seqs(mut self, max_seqs: usize) -> Self {
        self.max_seqs = max_seqs;
        self
    }

    /// Sets whether to disable the key-value cache.
    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = no_kv_cache;
        self
    }

    /// Sets the chat template configuration.
    pub fn with_chat_template(mut self, chat_template: String) -> Self {
        self.chat_template = Some(chat_template);
        self
    }

    /// Sets the chat template configuration if provided.
    pub fn with_chat_template_optional(mut self, chat_template: Option<String>) -> Self {
        if let Some(chat_template) = chat_template {
            self = self.with_chat_template(chat_template);
        }
        self
    }

    /// Sets an explicit JINJA chat template file.
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    /// Sets an explicit JINJA chat template file if provided.
    pub fn with_jinja_explicit_optional(mut self, jinja_explicit: Option<String>) -> Self {
        if let Some(jinja_explicit) = jinja_explicit {
            self = self.with_jinja_explicit(jinja_explicit);
        }
        self
    }

    /// Sets the token source for authentication.
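    ///
    /// A minimal sketch using the cached Hugging Face token (the default source):
    ///
    /// ```ignore
    /// use mistralrs_core::TokenSource;
    ///
    /// let builder = MistralRsForServerBuilder::new()
    ///     .with_token_source(TokenSource::CacheToken);
    /// ```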
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Sets whether to run in interactive mode.
    pub fn with_interactive_mode(mut self, interactive_mode: bool) -> Self {
        self.interactive_mode = interactive_mode;
        self
    }

    /// Sets the number of prefix caches to hold on the device.
    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
        self.prefix_cache_n = prefix_cache_n;
        self
    }

    /// Sets the device layer mapping.
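    ///
    /// An illustrative sketch of the `ORD:NUM` entries (hypothetical layer counts):
    ///
    /// ```ignore
    /// // 16 layers on device ordinal 0 and 16 on ordinal 1; remaining layers run on the CPU.
    /// let builder = MistralRsForServerBuilder::new()
    ///     .with_num_device_layers(vec!["0:16".to_string(), "1:16".to_string()]);
    /// ```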
    pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
        self.num_device_layers = Some(num_device_layers);
        self
    }

    /// Sets the device layer mapping if provided.
    pub fn with_num_device_layers_optional(
        mut self,
        num_device_layers: Option<Vec<String>>,
    ) -> Self {
        if let Some(num_device_layers) = num_device_layers {
            self = self.with_num_device_layers(num_device_layers);
        }
        self
    }

    /// Sets the in-situ quantization method.
    pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
        self.in_situ_quant = Some(in_situ_quant);
        self
    }

    /// Sets the in-situ quantization method if provided.
    pub fn with_in_situ_quant_optional(mut self, in_situ_quant: Option<String>) -> Self {
        if let Some(in_situ_quant) = in_situ_quant {
            self = self.with_in_situ_quant(in_situ_quant);
        }
        self
    }

    /// Sets the GPU memory allocation for PagedAttention KV cache.
    pub fn with_paged_attn_gpu_mem(mut self, paged_attn_gpu_mem: usize) -> Self {
        self.paged_attn_gpu_mem = Some(paged_attn_gpu_mem);
        self
    }

    /// Sets the GPU memory allocation for PagedAttention KV cache if provided.
    pub fn with_paged_attn_gpu_mem_optional(mut self, paged_attn_gpu_mem: Option<usize>) -> Self {
        if let Some(paged_attn_gpu_mem) = paged_attn_gpu_mem {
            self = self.with_paged_attn_gpu_mem(paged_attn_gpu_mem);
        }
        self
    }

    /// Sets the percentage of GPU memory to utilize for PagedAttention.
    pub fn with_paged_attn_gpu_mem_usage(mut self, paged_attn_gpu_mem_usage: f32) -> Self {
        self.paged_attn_gpu_mem_usage = Some(paged_attn_gpu_mem_usage);
        self
    }

    /// Sets the percentage of GPU memory to utilize for PagedAttention if provided.
    pub fn with_paged_attn_gpu_mem_usage_optional(
        mut self,
        paged_attn_gpu_mem_usage: Option<f32>,
    ) -> Self {
        if let Some(paged_attn_gpu_mem_usage) = paged_attn_gpu_mem_usage {
            self = self.with_paged_attn_gpu_mem_usage(paged_attn_gpu_mem_usage);
        }
        self
    }

    /// Sets the total context length for KV cache allocation.
    pub fn with_paged_ctxt_len(mut self, paged_ctxt_len: usize) -> Self {
        self.paged_ctxt_len = Some(paged_ctxt_len);
        self
    }

    /// Sets the total context length for KV cache allocation if provided.
    pub fn with_paged_ctxt_len_optional(mut self, paged_ctxt_len: Option<usize>) -> Self {
        if let Some(paged_ctxt_len) = paged_ctxt_len {
            self = self.with_paged_ctxt_len(paged_ctxt_len);
        }
        self
    }

    /// Sets the block size for PagedAttention.
    pub fn with_paged_attn_block_size(mut self, paged_attn_block_size: usize) -> Self {
        self.paged_attn_block_size = Some(paged_attn_block_size);
        self
    }

    /// Sets the block size for PagedAttention if provided.
    pub fn with_paged_attn_block_size_optional(
        mut self,
        paged_attn_block_size: Option<usize>,
    ) -> Self {
        if let Some(paged_attn_block_size) = paged_attn_block_size {
            self = self.with_paged_attn_block_size(paged_attn_block_size);
        }
        self
    }

    /// Sets whether to disable PagedAttention on CUDA devices.
    pub fn with_no_paged_attn(mut self, no_paged_attn: bool) -> Self {
        self.no_paged_attn = no_paged_attn;
        self
    }

    /// Sets whether to enable PagedAttention.
    pub fn with_paged_attn(mut self, paged_attn: bool) -> Self {
        self.paged_attn = paged_attn;
        self
    }

    /// Sets the prompt chunking size. Must be greater than zero; `build` returns an error for a value of 0.
    pub fn with_prompt_chunksize(mut self, prompt_chunksize: usize) -> Self {
        self.prompt_chunksize = Some(prompt_chunksize);
        self
    }

    /// Sets the prompt chunking size if provided.
    pub fn with_prompt_chunksize_optional(mut self, prompt_chunksize: Option<usize>) -> Self {
        if let Some(prompt_chunksize) = prompt_chunksize {
            self = self.with_prompt_chunksize(prompt_chunksize);
        }
        self
    }

    /// Sets whether to force CPU-only execution.
    pub fn with_cpu(mut self, cpu: bool) -> Self {
        self.cpu = cpu;
        self
    }

    /// Sets whether to enable web search functionality.
    pub fn with_enable_search(mut self, enable_search: bool) -> Self {
        self.enable_search = enable_search;
        self
    }

    /// Sets the BERT model for web search assistance.
    pub fn with_search_bert_model(mut self, search_bert_model: String) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

    /// Override the search function used when `web_search_options` is enabled.
    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(callback);
        self
    }

    /// Builds the configured mistral.rs instance.
    ///
    /// ### Examples
    ///
    /// ```ignore
    /// use mistralrs_server_core::mistralrs_for_server_builder::MistralRsForServerBuilder;
    ///
    /// let shared_mistralrs = MistralRsForServerBuilder::new()
    ///     .with_model(model)
    ///     .with_in_situ_quant("8".to_string())
    ///     .with_paged_attn(true)
    ///     .build()
    ///     .await?;
    /// ```
    pub async fn build(mut self) -> Result<SharedMistralRsState> {
        // Forcing CPU-only execution also disables PagedAttention
        // (this was originally handled together with the device configuration).
        if self.cpu {
            self.no_paged_attn = true;
        }

        let model = self.model.context("Model was None")?;

        let tgt_non_granular_index = get_tgt_non_granular_index(&model);
        let dtype = get_model_dtype(&model)?;
        let auto_device_map_params = get_auto_device_map_params(&model)?;

        if tgt_non_granular_index.is_some() {
            self.max_seqs = 1;
        }

        let prompt_chunksize = match self.prompt_chunksize {
            Some(0) => {
                anyhow::bail!("`prompt_chunksize` must be a strictly positive integer, got 0.")
            }
            Some(x) => Some(NonZeroUsize::new(x).unwrap()),
            None => None,
        };

        let max_seq_len = auto_device_map_params.max_seq_len();

        let device = if let Some(device) = self.device {
            device
        } else {
            init_device(self.cpu, self.seed)?
        };

        let mapper = init_mapper(&self.num_device_layers, &auto_device_map_params);
        let no_paged_attn = configure_no_paged_attn(&device, self.no_paged_attn, self.paged_attn);

        // Allocate 0.5 GB of CPU memory just as a placeholder.
        // Nothing happens here as we have no `swap_out`, see `_preempt_by_swap`.
        let cache_config = init_cache_config(
            self.paged_attn_block_size,
            self.paged_attn_gpu_mem,
            self.paged_attn_gpu_mem_usage,
            self.paged_ctxt_len,
            no_paged_attn,
            max_seq_len,
        )?;

        // Configure this last to prevent arg moves
        let loader: Box<dyn Loader> = LoaderBuilder::new(model)
            .with_no_kv_cache(self.no_kv_cache)
            .with_chat_template(self.chat_template)
            .with_prompt_chunksize(prompt_chunksize)
            .with_jinja_explicit(self.jinja_explicit)
            .build()?;

        mistralrs_instance_info(&*loader);

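        // An ISQ value that fails to parse becomes `None` here via `.ok()`, so
        // in-situ quantization is silently skipped rather than reported as an error.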
        let isq = self
            .in_situ_quant
            .as_ref()
            .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

        let pipeline: LoadedPipeline = loader.load_model_from_hf(
            None,
            self.token_source,
            &dtype,
            &device,
            false,
            mapper,
            isq,
            cache_config,
        )?;
        info!("Model loaded.");

        let scheduler_config = init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;

        let bert_model = get_bert_model(self.enable_search, self.search_bert_model);

        let mistralrs = MistralRsBuilder::new(
            pipeline,
            scheduler_config,
            !self.interactive_mode,
            bert_model,
        )
        .with_opt_log(self.log)
        .with_truncate_sequence(self.truncate_sequence)
        .with_no_kv_cache(self.no_kv_cache)
        .with_prefix_cache_n(self.prefix_cache_n)
        .build();

        Ok(mistralrs)
    }
}

// TODO: replace with best device?
/// Initializes the device to be used for computation, optionally forcing CPU usage and setting a seed.
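///
/// A minimal sketch (forcing CPU with a fixed seed):
///
/// ```ignore
/// let device = init_device(true, Some(42))?; // always `Device::Cpu`
/// ```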
fn init_device(force_cpu: bool, seed: Option<u64>) -> Result<candle_core::Device> {
    #[cfg(feature = "metal")]
    let device = if force_cpu {
        Device::Cpu
    } else {
        Device::new_metal(0)?
    };
    #[cfg(not(feature = "metal"))]
    #[allow(clippy::if_same_then_else)]
    let device = if force_cpu {
        Device::Cpu
    } else if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = seed {
        device.set_seed(seed)?;
    }

    Ok(device)
}

/// Initializes the device mapping configuration for distributing model layers.
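///
/// A minimal sketch, assuming `auto_device_map_params` is in scope and two devices are visible:
///
/// ```ignore
/// // Each entry is `ORD:NUM`: 16 layers on ordinal 0 and 16 on ordinal 1.
/// let mapper = init_mapper(
///     &Some(vec!["0:16".to_string(), "1:16".to_string()]),
///     &auto_device_map_params,
/// );
/// ```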
fn init_mapper(
    num_device_layers: &Option<Vec<String>>,
    auto_device_map_params: &AutoDeviceMapParams,
) -> DeviceMapSetting {
    // Parse device mapper
    if let Some(device_layers) = num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params.clone())
    }
}

/// Logs hardware feature information and the model's sampling strategy and kind.
fn mistralrs_instance_info(loader: &dyn Loader) {
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );

    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    info!("Model kind is: {}", loader.get_kind());
}

/// Determines whether paged attention should be disabled based on device type and preferences.
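///
/// A sketch of the decision logic (hypothetical `device`):
///
/// ```ignore
/// // CUDA (or NCCL): honor the explicit `no_paged_attn` flag.
/// // Metal: disabled unless `paged_attn` was requested.
/// // Anything else (e.g. CPU): always disabled.
/// let disabled = configure_no_paged_attn(&device, false, true);
/// ```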
fn configure_no_paged_attn(device: &Device, no_paged_attn: bool, paged_attn: bool) -> bool {
    if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        no_paged_attn
    } else if device.is_metal() {
        !paged_attn
    } else {
        true
    }
}

/// Initializes the cache configuration for paged attention based on provided parameters.
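///
/// A minimal sketch (hypothetical values) sizing the KV cache by context length,
/// the highest-priority setting, on a device where PagedAttention is supported:
///
/// ```ignore
/// // Corresponds to the `pa-ctxt-len` setting; block size and memory options unset.
/// let cache_config = init_cache_config(None, None, None, Some(8192), false, 4096)?;
/// ```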
fn init_cache_config(
    paged_attn_block_size: Option<usize>,
    paged_attn_gpu_mem: Option<usize>,
    paged_attn_gpu_mem_usage: Option<f32>,
    paged_ctxt_len: Option<usize>,
    no_paged_attn: bool,
    max_seq_len: usize,
) -> Result<Option<PagedAttentionConfig>> {
    match (
        paged_attn_block_size,
        paged_attn_gpu_mem,
        paged_attn_gpu_mem_usage,
        paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
        )?)),
        (block_size, None, None, Some(ctxt), true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
        )?)),
        (block_size, None, Some(f), None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
        )?)),
        (block_size, Some(m), None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
        )?)),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified, defaulting to the usage value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?))
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified, defaulting to the context length value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
            )?))
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified, defaulting to the usage value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?))
        }
        (_, _, _, _, _, _) => Ok(None),
    }
}

/// Initializes the scheduler configuration based on cache settings and pipeline metadata.
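///
/// A minimal sketch (hypothetical `pipeline`): with no PagedAttention cache config,
/// the fixed-capacity default scheduler is selected.
///
/// ```ignore
/// let config = init_scheduler_config(&None, &pipeline, 16).await;
/// ```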
async fn init_scheduler_config(
    cache_config: &Option<PagedAttentionConfig>,
    pipeline: &LoadedPipeline,
    args_max_seqs: usize,
) -> SchedulerConfig {
    if cache_config.is_some() {
        // Handle case where we may have device mapping
        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: args_max_seqs,
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
        }
    }
}

/// Creates a BERT embedding model configuration for search functionality.
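///
/// ### Examples
///
/// ```ignore
/// // Search enabled with no override falls back to the default BERT model
/// // (Snowflake Arctic Embed L).
/// let bert = get_bert_model(true, None);
/// assert!(bert.is_some());
///
/// // Search disabled yields no model.
/// assert!(get_bert_model(false, None).is_none());
/// ```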
pub fn get_bert_model(
    enable_search: bool,
    search_bert_model: Option<String>,
) -> Option<BertEmbeddingModel> {
    if enable_search {
        Some(
            search_bert_model
                .map(BertEmbeddingModel::Custom)
                .unwrap_or_default(),
        )
    } else {
        None
    }
}