mistralrs/
model_builder_trait.rs

1use mistralrs_core::{AddModelConfig, Pipeline, SchedulerConfig};
2use std::sync::Arc;
3use tokio::sync::Mutex;
4
5use crate::Model;
6
7/// Enum representing all possible model builders that can be used with MultiModelBuilder.
8pub enum AnyModelBuilder {
9    Text(crate::TextModelBuilder),
10    Vision(crate::VisionModelBuilder),
11    Gguf(crate::GgufModelBuilder),
12    Diffusion(crate::DiffusionModelBuilder),
13    Speech(crate::SpeechModelBuilder),
14    Embedding(crate::EmbeddingModelBuilder),
15}
16
17impl AnyModelBuilder {
18    /// Get the default model ID for this builder.
19    pub fn model_id(&self) -> String {
20        match self {
21            AnyModelBuilder::Text(b) => b.model_id.clone(),
22            AnyModelBuilder::Vision(b) => b.model_id.clone(),
23            AnyModelBuilder::Gguf(b) => b.model_id.clone(),
24            AnyModelBuilder::Diffusion(b) => b.model_id.clone(),
25            AnyModelBuilder::Speech(b) => b.model_id.clone(),
26            AnyModelBuilder::Embedding(b) => b.model_id.clone(),
27        }
28    }
29
30    /// Build the pipeline and configuration for this model.
31    pub async fn build_pipeline(
32        self,
33    ) -> anyhow::Result<(Arc<Mutex<dyn Pipeline>>, SchedulerConfig, AddModelConfig)> {
34        match self {
35            AnyModelBuilder::Text(b) => build_text_pipeline(b).await,
36            AnyModelBuilder::Vision(b) => build_vision_pipeline(b).await,
37            AnyModelBuilder::Gguf(b) => build_gguf_pipeline(b).await,
38            AnyModelBuilder::Diffusion(b) => build_diffusion_pipeline(b).await,
39            AnyModelBuilder::Speech(b) => build_speech_pipeline(b).await,
40            AnyModelBuilder::Embedding(b) => build_embedding_pipeline(b).await,
41        }
42    }
43}
44
45// Conversion implementations
46impl From<crate::TextModelBuilder> for AnyModelBuilder {
47    fn from(b: crate::TextModelBuilder) -> Self {
48        AnyModelBuilder::Text(b)
49    }
50}
51
52impl From<crate::VisionModelBuilder> for AnyModelBuilder {
53    fn from(b: crate::VisionModelBuilder) -> Self {
54        AnyModelBuilder::Vision(b)
55    }
56}
57
58impl From<crate::GgufModelBuilder> for AnyModelBuilder {
59    fn from(b: crate::GgufModelBuilder) -> Self {
60        AnyModelBuilder::Gguf(b)
61    }
62}
63
64impl From<crate::DiffusionModelBuilder> for AnyModelBuilder {
65    fn from(b: crate::DiffusionModelBuilder) -> Self {
66        AnyModelBuilder::Diffusion(b)
67    }
68}
69
70impl From<crate::SpeechModelBuilder> for AnyModelBuilder {
71    fn from(b: crate::SpeechModelBuilder) -> Self {
72        AnyModelBuilder::Speech(b)
73    }
74}
75
76impl From<crate::EmbeddingModelBuilder> for AnyModelBuilder {
77    fn from(b: crate::EmbeddingModelBuilder) -> Self {
78        AnyModelBuilder::Embedding(b)
79    }
80}
81
82/// Builder for creating a Model with multiple models.
83pub struct MultiModelBuilder {
84    builders: Vec<AnyModelBuilder>,
85    default_model_id: Option<String>,
86}
87
88impl Default for MultiModelBuilder {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
94impl MultiModelBuilder {
95    /// Create a new MultiModelBuilder.
96    pub fn new() -> Self {
97        Self {
98            builders: Vec::new(),
99            default_model_id: None,
100        }
101    }
102
103    /// Add a model. The model ID will be the pipeline's model_id (e.g., "google/gemma-3-4b-it").
104    pub fn add_model<B: Into<AnyModelBuilder>>(mut self, builder: B) -> Self {
105        self.builders.push(builder.into());
106        self
107    }
108
109    /// Set the default model by its model ID (e.g., "google/gemma-3-4b-it").
110    pub fn with_default_model(mut self, model_id: impl ToString) -> Self {
111        self.default_model_id = Some(model_id.to_string());
112        self
113    }
114
115    /// Build the multi-model Model instance.
116    pub async fn build(self) -> anyhow::Result<Model> {
117        if self.builders.is_empty() {
118            anyhow::bail!("MultiModelBuilder requires at least one model to be added");
119        }
120
121        // Build the first model to create the initial MistralRs instance
122        let mut builders_iter = self.builders.into_iter();
123        let first_builder = builders_iter.next().unwrap();
124
125        let (pipeline, scheduler_config, add_model_config) = first_builder.build_pipeline().await?;
126
127        // Create the MistralRsBuilder for the first model
128        // The model ID will be taken from the pipeline's name
129        let mut runner_builder = mistralrs_core::MistralRsBuilder::new(
130            pipeline,
131            scheduler_config,
132            add_model_config.engine_config.throughput_logging_enabled,
133            add_model_config.engine_config.search_embedding_model,
134        );
135
136        if let Some(cb) = add_model_config.engine_config.search_callback.clone() {
137            runner_builder = runner_builder.with_search_callback(cb);
138        }
139
140        for (name, cb) in &add_model_config.engine_config.tool_callbacks {
141            runner_builder = runner_builder.with_tool_callback(name.clone(), cb.clone());
142        }
143
144        for (name, callback_with_tool) in &add_model_config.engine_config.tool_callbacks_with_tools
145        {
146            runner_builder = runner_builder.with_tool_callback_and_tool(
147                name.clone(),
148                callback_with_tool.callback.clone(),
149                callback_with_tool.tool.clone(),
150            );
151        }
152
153        if let Some(mcp_config) = add_model_config.mcp_client_config.clone() {
154            runner_builder = runner_builder.with_mcp_client(mcp_config);
155        }
156
157        if let Some(loader_config) = add_model_config.loader_config.clone() {
158            runner_builder = runner_builder.with_loader_config(loader_config);
159        }
160
161        runner_builder = runner_builder
162            .with_no_kv_cache(add_model_config.engine_config.no_kv_cache)
163            .with_no_prefix_cache(add_model_config.engine_config.no_prefix_cache)
164            .with_prefix_cache_n(add_model_config.engine_config.prefix_cache_n);
165
166        let mistralrs = runner_builder.build().await;
167
168        // Add remaining models using their pipeline names as IDs
169        for builder in builders_iter {
170            let model_id = builder.model_id();
171            let (pipeline, scheduler_config, add_model_config) = builder.build_pipeline().await?;
172            mistralrs
173                .add_model(model_id, pipeline, scheduler_config, add_model_config)
174                .await
175                .map_err(|e| anyhow::anyhow!(e))?;
176        }
177
178        // Set the default model if specified
179        if let Some(default_id) = self.default_model_id {
180            mistralrs
181                .set_default_model_id(&default_id)
182                .map_err(|e| anyhow::anyhow!(e))?;
183        }
184        // Otherwise, the first model is already the default (set by MistralRs::new)
185
186        Ok(Model::new(mistralrs))
187    }
188}
189
190// Pipeline building functions for each model type.
191// These are public so individual builders can reuse them to avoid code duplication.
192
193/// Create a Model from pipeline components.
194/// This is the common code path used by all individual builder `build()` methods.
195pub async fn build_model_from_pipeline(
196    pipeline: Arc<Mutex<dyn mistralrs_core::Pipeline>>,
197    scheduler_config: SchedulerConfig,
198    add_model_config: AddModelConfig,
199) -> Model {
200    let mut runner_builder = mistralrs_core::MistralRsBuilder::new(
201        pipeline,
202        scheduler_config,
203        add_model_config.engine_config.throughput_logging_enabled,
204        add_model_config.engine_config.search_embedding_model,
205    );
206
207    if let Some(cb) = add_model_config.engine_config.search_callback.clone() {
208        runner_builder = runner_builder.with_search_callback(cb);
209    }
210
211    for (name, cb) in &add_model_config.engine_config.tool_callbacks {
212        runner_builder = runner_builder.with_tool_callback(name.clone(), cb.clone());
213    }
214
215    for (name, callback_with_tool) in &add_model_config.engine_config.tool_callbacks_with_tools {
216        runner_builder = runner_builder.with_tool_callback_and_tool(
217            name.clone(),
218            callback_with_tool.callback.clone(),
219            callback_with_tool.tool.clone(),
220        );
221    }
222
223    if let Some(mcp_config) = add_model_config.mcp_client_config.clone() {
224        runner_builder = runner_builder.with_mcp_client(mcp_config);
225    }
226
227    if let Some(loader_config) = add_model_config.loader_config.clone() {
228        runner_builder = runner_builder.with_loader_config(loader_config);
229    }
230
231    runner_builder = runner_builder
232        .with_no_kv_cache(add_model_config.engine_config.no_kv_cache)
233        .with_no_prefix_cache(add_model_config.engine_config.no_prefix_cache)
234        .with_prefix_cache_n(add_model_config.engine_config.prefix_cache_n);
235
236    Model::new(runner_builder.build().await)
237}
238
239/// Build a text model pipeline from a TextModelBuilder.
240/// Returns the pipeline, scheduler config, and AddModelConfig needed for Model creation.
241pub async fn build_text_pipeline(
242    builder: crate::TextModelBuilder,
243) -> anyhow::Result<(Arc<Mutex<dyn Pipeline>>, SchedulerConfig, AddModelConfig)> {
244    use crate::best_device;
245    use mistralrs_core::*;
246
247    let config = NormalSpecificConfig {
248        topology: builder.topology.clone(),
249        organization: builder.organization,
250        write_uqff: builder.write_uqff.clone(),
251        from_uqff: builder.from_uqff.clone(),
252        imatrix: builder.imatrix.clone(),
253        calibration_file: builder.calibration_file.clone(),
254        hf_cache_path: builder.hf_cache_path.clone(),
255        matformer_config_path: builder.matformer_config_path.clone(),
256        matformer_slice_name: builder.matformer_slice_name.clone(),
257    };
258
259    if builder.with_logging {
260        initialize_logging();
261    }
262
263    let loader = NormalLoaderBuilder::new(
264        config,
265        builder.chat_template.clone(),
266        builder.tokenizer_json.clone(),
267        Some(builder.model_id.clone()),
268        builder.no_kv_cache,
269        builder.jinja_explicit.clone(),
270    )
271    .build(builder.loader_type.clone())?;
272
273    let pipeline = loader.load_model_from_hf(
274        builder.hf_revision.clone(),
275        builder.token_source.clone(),
276        &builder.dtype,
277        &builder
278            .device
279            .clone()
280            .unwrap_or(best_device(builder.force_cpu).unwrap()),
281        !builder.with_logging,
282        builder
283            .device_mapping
284            .clone()
285            .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
286        builder.isq,
287        builder.paged_attn_cfg,
288    )?;
289
290    let scheduler_config = match builder.paged_attn_cfg {
291        Some(_) => {
292            let config = pipeline
293                .lock()
294                .await
295                .get_metadata()
296                .cache_config
297                .as_ref()
298                .cloned();
299
300            if let Some(config) = config {
301                SchedulerConfig::PagedAttentionMeta {
302                    max_num_seqs: builder.max_num_seqs,
303                    config,
304                }
305            } else {
306                SchedulerConfig::DefaultScheduler {
307                    method: DefaultSchedulerMethod::Fixed(builder.max_num_seqs.try_into()?),
308                }
309            }
310        }
311        None => SchedulerConfig::DefaultScheduler {
312            method: DefaultSchedulerMethod::Fixed(builder.max_num_seqs.try_into()?),
313        },
314    };
315
316    let engine_config = EngineConfig {
317        throughput_logging_enabled: builder.throughput_logging,
318        search_embedding_model: builder.search_embedding_model,
319        search_callback: builder.search_callback.clone(),
320        tool_callbacks: builder.tool_callbacks.clone(),
321        tool_callbacks_with_tools: builder
322            .tool_callbacks_with_tools
323            .iter()
324            .map(|(k, v)| {
325                (
326                    k.clone(),
327                    mistralrs_core::ToolCallbackWithTool {
328                        callback: v.callback.clone(),
329                        tool: v.tool.clone(),
330                    },
331                )
332            })
333            .collect(),
334        no_kv_cache: builder.no_kv_cache,
335        no_prefix_cache: builder.prefix_cache_n.is_none(),
336        prefix_cache_n: builder.prefix_cache_n.unwrap_or(16),
337        disable_eos_stop: false,
338    };
339
340    // Create loader config for unload/reload support
341    let device = builder
342        .device
343        .clone()
344        .unwrap_or(best_device(builder.force_cpu).unwrap());
345    let device_map_setting = builder
346        .device_mapping
347        .clone()
348        .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()));
349
350    // Convert from_uqff Vec<PathBuf> to semicolon-separated string if present
351    let from_uqff_str = builder.from_uqff.as_ref().map(|paths| {
352        paths
353            .iter()
354            .map(|p| p.to_string_lossy())
355            .collect::<Vec<_>>()
356            .join(";")
357    });
358
359    let loader_config = ModelLoaderConfig {
360        model_selected: ModelSelected::Plain {
361            model_id: builder.model_id.clone(),
362            tokenizer_json: builder.tokenizer_json.clone(),
363            arch: builder.loader_type,
364            dtype: builder.dtype,
365            topology: builder.topology_path.clone(),
366            organization: Some(builder.organization),
367            write_uqff: builder.write_uqff.clone(),
368            from_uqff: from_uqff_str,
369            imatrix: builder.imatrix.clone(),
370            calibration_file: builder.calibration_file.clone(),
371            max_seq_len: AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN,
372            max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
373            hf_cache_path: builder.hf_cache_path.clone(),
374            matformer_config_path: builder.matformer_config_path.clone(),
375            matformer_slice_name: builder.matformer_slice_name.clone(),
376        },
377        token_source: builder.token_source.clone(),
378        hf_revision: builder.hf_revision.clone(),
379        dtype: builder.dtype,
380        device,
381        device_map_setting,
382        isq: builder.isq,
383        paged_attn_config: builder.paged_attn_cfg,
384        silent: !builder.with_logging,
385        chat_template: builder.chat_template.clone(),
386        jinja_explicit: builder.jinja_explicit.clone(),
387    };
388
389    let add_model_config = AddModelConfig {
390        engine_config,
391        mcp_client_config: builder.mcp_client_config.clone(),
392        loader_config: Some(loader_config),
393    };
394
395    Ok((pipeline, scheduler_config, add_model_config))
396}
397
398/// Build a vision model pipeline from a VisionModelBuilder.
399/// Returns the pipeline, scheduler config, and AddModelConfig needed for Model creation.
400pub async fn build_vision_pipeline(
401    builder: crate::VisionModelBuilder,
402) -> anyhow::Result<(Arc<Mutex<dyn Pipeline>>, SchedulerConfig, AddModelConfig)> {
403    use crate::best_device;
404    use mistralrs_core::*;
405
406    let config = VisionSpecificConfig {
407        topology: builder.topology.clone(),
408        write_uqff: builder.write_uqff.clone(),
409        from_uqff: builder.from_uqff.clone(),
410        max_edge: builder.max_edge,
411        calibration_file: builder.calibration_file.clone(),
412        imatrix: builder.imatrix.clone(),
413        hf_cache_path: builder.hf_cache_path.clone(),
414        matformer_config_path: builder.matformer_config_path.clone(),
415        matformer_slice_name: builder.matformer_slice_name.clone(),
416    };
417
418    if builder.with_logging {
419        initialize_logging();
420    }
421
422    let loader = VisionLoaderBuilder::new(
423        config,
424        builder.chat_template.clone(),
425        builder.tokenizer_json.clone(),
426        Some(builder.model_id.clone()),
427        builder.jinja_explicit.clone(),
428    )
429    .build(builder.loader_type.clone());
430
431    let pipeline = loader.load_model_from_hf(
432        builder.hf_revision.clone(),
433        builder.token_source.clone(),
434        &builder.dtype,
435        &builder
436            .device
437            .clone()
438            .unwrap_or(best_device(builder.force_cpu).unwrap()),
439        !builder.with_logging,
440        builder
441            .device_mapping
442            .clone()
443            .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_vision())),
444        builder.isq,
445        builder.paged_attn_cfg,
446    )?;
447
448    let scheduler_config = match builder.paged_attn_cfg {
449        Some(_) => {
450            let config = pipeline
451                .lock()
452                .await
453                .get_metadata()
454                .cache_config
455                .as_ref()
456                .cloned();
457
458            if let Some(config) = config {
459                SchedulerConfig::PagedAttentionMeta {
460                    max_num_seqs: builder.max_num_seqs,
461                    config,
462                }
463            } else {
464                SchedulerConfig::DefaultScheduler {
465                    method: DefaultSchedulerMethod::Fixed(builder.max_num_seqs.try_into()?),
466                }
467            }
468        }
469        None => SchedulerConfig::DefaultScheduler {
470            method: DefaultSchedulerMethod::Fixed(builder.max_num_seqs.try_into()?),
471        },
472    };
473
474    let engine_config = EngineConfig {
475        throughput_logging_enabled: builder.throughput_logging,
476        search_embedding_model: builder.search_embedding_model,
477        search_callback: builder.search_callback.clone(),
478        tool_callbacks: builder.tool_callbacks.clone(),
479        tool_callbacks_with_tools: builder
480            .tool_callbacks_with_tools
481            .iter()
482            .map(|(k, v)| {
483                (
484                    k.clone(),
485                    mistralrs_core::ToolCallbackWithTool {
486                        callback: v.callback.clone(),
487                        tool: v.tool.clone(),
488                    },
489                )
490            })
491            .collect(),
492        no_kv_cache: false,
493        no_prefix_cache: builder.prefix_cache_n.is_none(),
494        prefix_cache_n: builder.prefix_cache_n.unwrap_or(16),
495        disable_eos_stop: false,
496    };
497
498    // Create loader config for unload/reload support
499    let device = builder
500        .device
501        .clone()
502        .unwrap_or(best_device(builder.force_cpu).unwrap());
503    let device_map_setting = builder
504        .device_mapping
505        .clone()
506        .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_vision()));
507
508    // Convert from_uqff Vec<PathBuf> to semicolon-separated string if present
509    let from_uqff_str = builder.from_uqff.as_ref().map(|paths| {
510        paths
511            .iter()
512            .map(|p| p.to_string_lossy())
513            .collect::<Vec<_>>()
514            .join(";")
515    });
516
517    let loader_config = ModelLoaderConfig {
518        model_selected: ModelSelected::VisionPlain {
519            model_id: builder.model_id.clone(),
520            tokenizer_json: builder.tokenizer_json.clone(),
521            arch: builder.loader_type,
522            dtype: builder.dtype,
523            topology: builder.topology_path.clone(),
524            write_uqff: builder.write_uqff.clone(),
525            from_uqff: from_uqff_str,
526            max_edge: builder.max_edge,
527            calibration_file: builder.calibration_file.clone(),
528            imatrix: builder.imatrix.clone(),
529            max_seq_len: AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN,
530            max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
531            max_num_images: AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES,
532            max_image_length: AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH,
533            hf_cache_path: builder.hf_cache_path.clone(),
534            matformer_config_path: builder.matformer_config_path.clone(),
535            matformer_slice_name: builder.matformer_slice_name.clone(),
536        },
537        token_source: builder.token_source.clone(),
538        hf_revision: builder.hf_revision.clone(),
539        dtype: builder.dtype,
540        device,
541        device_map_setting,
542        isq: builder.isq,
543        paged_attn_config: builder.paged_attn_cfg,
544        silent: !builder.with_logging,
545        chat_template: builder.chat_template.clone(),
546        jinja_explicit: builder.jinja_explicit.clone(),
547    };
548
549    let add_model_config = AddModelConfig {
550        engine_config,
551        mcp_client_config: None,
552        loader_config: Some(loader_config),
553    };
554
555    Ok((pipeline, scheduler_config, add_model_config))
556}
557
558/// Build a GGUF model pipeline from a GgufModelBuilder.
559/// Returns the pipeline, scheduler config, and AddModelConfig needed for Model creation.
560pub async fn build_gguf_pipeline(
561    builder: crate::GgufModelBuilder,
562) -> anyhow::Result<(Arc<Mutex<dyn Pipeline>>, SchedulerConfig, AddModelConfig)> {
563    use crate::best_device;
564    use mistralrs_core::*;
565
566    let config = GGUFSpecificConfig {
567        topology: builder.topology.clone(),
568    };
569
570    if builder.with_logging {
571        initialize_logging();
572    }
573
574    let loader = GGUFLoaderBuilder::new(
575        builder.chat_template.clone(),
576        builder.tok_model_id.clone(),
577        builder.model_id.clone(),
578        builder.files.clone(),
579        config,
580        builder.no_kv_cache,
581        builder.jinja_explicit.clone(),
582    )
583    .build();
584
585    let pipeline = loader.load_model_from_hf(
586        builder.hf_revision.clone(),
587        builder.token_source.clone(),
588        &ModelDType::Auto,
589        &builder
590            .device
591            .clone()
592            .unwrap_or(best_device(builder.force_cpu).unwrap()),
593        !builder.with_logging,
594        builder
595            .device_mapping
596            .clone()
597            .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
598        None,
599        builder.paged_attn_cfg,
600    )?;
601
602    let scheduler_config = match builder.paged_attn_cfg {
603        Some(_) => {
604            let config = pipeline
605                .lock()
606                .await
607                .get_metadata()
608                .cache_config
609                .as_ref()
610                .unwrap()
611                .clone();
612
613            SchedulerConfig::PagedAttentionMeta {
614                max_num_seqs: builder.max_num_seqs,
615                config,
616            }
617        }
618        None => SchedulerConfig::DefaultScheduler {
619            method: DefaultSchedulerMethod::Fixed(builder.max_num_seqs.try_into()?),
620        },
621    };
622
623    let engine_config = EngineConfig {
624        throughput_logging_enabled: builder.throughput_logging,
625        search_embedding_model: builder.search_embedding_model,
626        search_callback: builder.search_callback.clone(),
627        tool_callbacks: builder.tool_callbacks.clone(),
628        tool_callbacks_with_tools: builder
629            .tool_callbacks_with_tools
630            .iter()
631            .map(|(k, v)| {
632                (
633                    k.clone(),
634                    mistralrs_core::ToolCallbackWithTool {
635                        callback: v.callback.clone(),
636                        tool: v.tool.clone(),
637                    },
638                )
639            })
640            .collect(),
641        no_kv_cache: builder.no_kv_cache,
642        no_prefix_cache: builder.prefix_cache_n.is_none(),
643        prefix_cache_n: builder.prefix_cache_n.unwrap_or(16),
644        disable_eos_stop: false,
645    };
646
647    // Create loader config for unload/reload support
648    let device = builder
649        .device
650        .clone()
651        .unwrap_or(best_device(builder.force_cpu).unwrap());
652    let device_map_setting = builder
653        .device_mapping
654        .clone()
655        .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()));
656
657    let loader_config = ModelLoaderConfig {
658        model_selected: ModelSelected::GGUF {
659            tok_model_id: builder.tok_model_id.clone(),
660            quantized_model_id: builder.model_id.clone(),
661            quantized_filename: builder.files.join(GGUF_MULTI_FILE_DELIMITER),
662            dtype: ModelDType::Auto,
663            topology: builder.topology_path.clone(),
664            max_seq_len: AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN,
665            max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
666        },
667        token_source: builder.token_source.clone(),
668        hf_revision: builder.hf_revision.clone(),
669        dtype: ModelDType::Auto,
670        device,
671        device_map_setting,
672        isq: None,
673        paged_attn_config: builder.paged_attn_cfg,
674        silent: !builder.with_logging,
675        chat_template: builder.chat_template.clone(),
676        jinja_explicit: builder.jinja_explicit.clone(),
677    };
678
679    let add_model_config = AddModelConfig {
680        engine_config,
681        mcp_client_config: None,
682        loader_config: Some(loader_config),
683    };
684
685    Ok((pipeline, scheduler_config, add_model_config))
686}
687
688/// Build a diffusion model pipeline from a DiffusionModelBuilder.
689/// Returns the pipeline, scheduler config, and AddModelConfig needed for Model creation.
690pub async fn build_diffusion_pipeline(
691    builder: crate::DiffusionModelBuilder,
692) -> anyhow::Result<(Arc<Mutex<dyn Pipeline>>, SchedulerConfig, AddModelConfig)> {
693    use crate::best_device;
694    use mistralrs_core::*;
695
696    if builder.with_logging {
697        initialize_logging();
698    }
699
700    let loader = DiffusionLoaderBuilder::new(Some(builder.model_id.clone()))
701        .build(builder.loader_type.clone());
702
703    let pipeline = loader.load_model_from_hf(
704        builder.hf_revision.clone(),
705        builder.token_source.clone(),
706        &builder.dtype,
707        &best_device(builder.force_cpu)?,
708        !builder.with_logging,
709        DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
710        None,
711        None,
712    )?;
713
714    let scheduler_config = SchedulerConfig::DefaultScheduler {
715        method: DefaultSchedulerMethod::Fixed(builder.max_num_seqs.try_into()?),
716    };
717
718    let engine_config = EngineConfig::default();
719
720    // Create loader config for unload/reload support
721    let device = best_device(builder.force_cpu)?;
722
723    let loader_config = ModelLoaderConfig {
724        model_selected: ModelSelected::DiffusionPlain {
725            model_id: builder.model_id.clone(),
726            arch: builder.loader_type,
727            dtype: builder.dtype,
728        },
729        token_source: builder.token_source.clone(),
730        hf_revision: builder.hf_revision.clone(),
731        dtype: builder.dtype,
732        device,
733        device_map_setting: DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
734        isq: None,
735        paged_attn_config: None,
736        silent: !builder.with_logging,
737        chat_template: None,
738        jinja_explicit: None,
739    };
740
741    let add_model_config = AddModelConfig {
742        engine_config,
743        mcp_client_config: None,
744        loader_config: Some(loader_config),
745    };
746
747    Ok((pipeline, scheduler_config, add_model_config))
748}
749
750/// Build a speech model pipeline from a SpeechModelBuilder.
751/// Returns the pipeline, scheduler config, and AddModelConfig needed for Model creation.
752pub async fn build_speech_pipeline(
753    builder: crate::SpeechModelBuilder,
754) -> anyhow::Result<(Arc<Mutex<dyn Pipeline>>, SchedulerConfig, AddModelConfig)> {
755    use crate::best_device;
756    use mistralrs_core::*;
757
758    if builder.with_logging {
759        initialize_logging();
760    }
761
762    let loader = SpeechLoader {
763        model_id: builder.model_id.clone(),
764        dac_model_id: builder.dac_model_id.clone(),
765        arch: builder.loader_type,
766        cfg: builder.cfg,
767    };
768
769    let pipeline = loader.load_model_from_hf(
770        builder.hf_revision.clone(),
771        builder.token_source.clone(),
772        &builder.dtype,
773        &best_device(builder.force_cpu)?,
774        !builder.with_logging,
775        DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
776        None,
777        None,
778    )?;
779
780    let scheduler_config = SchedulerConfig::DefaultScheduler {
781        method: DefaultSchedulerMethod::Fixed(builder.max_num_seqs.try_into()?),
782    };
783
784    let engine_config = EngineConfig::default();
785
786    // Create loader config for unload/reload support
787    let device = best_device(builder.force_cpu)?;
788
789    let loader_config = ModelLoaderConfig {
790        model_selected: ModelSelected::Speech {
791            model_id: builder.model_id.clone(),
792            dac_model_id: builder.dac_model_id.clone(),
793            arch: builder.loader_type,
794            dtype: builder.dtype,
795        },
796        token_source: builder.token_source.clone(),
797        hf_revision: builder.hf_revision.clone(),
798        dtype: builder.dtype,
799        device,
800        device_map_setting: DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
801        isq: None,
802        paged_attn_config: None,
803        silent: !builder.with_logging,
804        chat_template: None,
805        jinja_explicit: None,
806    };
807
808    let add_model_config = AddModelConfig {
809        engine_config,
810        mcp_client_config: None,
811        loader_config: Some(loader_config),
812    };
813
814    Ok((pipeline, scheduler_config, add_model_config))
815}
816
817/// Build an embedding model pipeline from an EmbeddingModelBuilder.
818/// Returns the pipeline, scheduler config, and AddModelConfig needed for Model creation.
819pub async fn build_embedding_pipeline(
820    builder: crate::EmbeddingModelBuilder,
821) -> anyhow::Result<(Arc<Mutex<dyn Pipeline>>, SchedulerConfig, AddModelConfig)> {
822    use crate::best_device;
823    use mistralrs_core::*;
824
825    let config = EmbeddingSpecificConfig {
826        topology: builder.topology.clone(),
827        write_uqff: builder.write_uqff.clone(),
828        from_uqff: builder.from_uqff.clone(),
829        hf_cache_path: builder.hf_cache_path.clone(),
830    };
831
832    if builder.with_logging {
833        initialize_logging();
834    }
835
836    let loader = EmbeddingLoaderBuilder::new(
837        config,
838        builder.tokenizer_json.clone(),
839        Some(builder.model_id.clone()),
840    )
841    .build(builder.loader_type.clone());
842
843    let pipeline = loader.load_model_from_hf(
844        builder.hf_revision.clone(),
845        builder.token_source.clone(),
846        &builder.dtype,
847        &builder
848            .device
849            .clone()
850            .unwrap_or(best_device(builder.force_cpu).unwrap()),
851        !builder.with_logging,
852        builder
853            .device_mapping
854            .clone()
855            .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
856        builder.isq,
857        None,
858    )?;
859
860    let scheduler_config = SchedulerConfig::DefaultScheduler {
861        method: DefaultSchedulerMethod::Fixed(builder.max_num_seqs.try_into()?),
862    };
863
864    let engine_config = EngineConfig {
865        throughput_logging_enabled: builder.throughput_logging,
866        ..Default::default()
867    };
868
869    // Create loader config for unload/reload support
870    let device = builder
871        .device
872        .clone()
873        .unwrap_or(best_device(builder.force_cpu).unwrap());
874    let device_map_setting = builder
875        .device_mapping
876        .clone()
877        .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()));
878
879    // Convert from_uqff Vec<PathBuf> to semicolon-separated string if present
880    let from_uqff_str = builder.from_uqff.as_ref().map(|paths| {
881        paths
882            .iter()
883            .map(|p| p.to_string_lossy())
884            .collect::<Vec<_>>()
885            .join(";")
886    });
887
888    let loader_config = ModelLoaderConfig {
889        model_selected: ModelSelected::Embedding {
890            model_id: builder.model_id.clone(),
891            tokenizer_json: builder.tokenizer_json.clone(),
892            arch: builder.loader_type,
893            dtype: builder.dtype,
894            topology: builder.topology_path.clone(),
895            write_uqff: builder.write_uqff.clone(),
896            from_uqff: from_uqff_str,
897            hf_cache_path: builder.hf_cache_path.clone(),
898        },
899        token_source: builder.token_source.clone(),
900        hf_revision: builder.hf_revision.clone(),
901        dtype: builder.dtype,
902        device,
903        device_map_setting,
904        isq: builder.isq,
905        paged_attn_config: None,
906        silent: !builder.with_logging,
907        chat_template: None,
908        jinja_explicit: None,
909    };
910
911    let add_model_config = AddModelConfig {
912        engine_config,
913        mcp_client_config: None,
914        loader_config: Some(loader_config),
915    };
916
917    Ok((pipeline, scheduler_config, add_model_config))
918}