1use mistralrs_core::{AddModelConfig, Pipeline, SchedulerConfig};
2use std::sync::Arc;
3use tokio::sync::Mutex;
4
5use crate::Model;
6
/// A model builder of any supported kind, so heterogeneous builders can be stored together
/// (for example in a [`MultiModelBuilder`]) and turned into a pipeline uniformly via
/// [`AnyModelBuilder::build_pipeline`].
pub enum AnyModelBuilder {
    /// Plain text model; built by `build_text_pipeline`.
    Text(crate::TextModelBuilder),
    /// Vision model; built by `build_vision_pipeline`.
    Vision(crate::VisionModelBuilder),
    /// GGUF-quantized model; built by `build_gguf_pipeline`.
    Gguf(crate::GgufModelBuilder),
    /// Diffusion model; built by `build_diffusion_pipeline`.
    Diffusion(crate::DiffusionModelBuilder),
    /// Speech model; built by `build_speech_pipeline`.
    Speech(crate::SpeechModelBuilder),
    /// Embedding model; built by `build_embedding_pipeline`.
    Embedding(crate::EmbeddingModelBuilder),
}
16
17impl AnyModelBuilder {
18 pub fn model_id(&self) -> String {
20 match self {
21 AnyModelBuilder::Text(b) => b.model_id.clone(),
22 AnyModelBuilder::Vision(b) => b.model_id.clone(),
23 AnyModelBuilder::Gguf(b) => b.model_id.clone(),
24 AnyModelBuilder::Diffusion(b) => b.model_id.clone(),
25 AnyModelBuilder::Speech(b) => b.model_id.clone(),
26 AnyModelBuilder::Embedding(b) => b.model_id.clone(),
27 }
28 }
29
30 pub async fn build_pipeline(
32 self,
33 ) -> anyhow::Result<(Arc<Mutex<dyn Pipeline>>, SchedulerConfig, AddModelConfig)> {
34 match self {
35 AnyModelBuilder::Text(b) => build_text_pipeline(b).await,
36 AnyModelBuilder::Vision(b) => build_vision_pipeline(b).await,
37 AnyModelBuilder::Gguf(b) => build_gguf_pipeline(b).await,
38 AnyModelBuilder::Diffusion(b) => build_diffusion_pipeline(b).await,
39 AnyModelBuilder::Speech(b) => build_speech_pipeline(b).await,
40 AnyModelBuilder::Embedding(b) => build_embedding_pipeline(b).await,
41 }
42 }
43}
44
45impl From<crate::TextModelBuilder> for AnyModelBuilder {
47 fn from(b: crate::TextModelBuilder) -> Self {
48 AnyModelBuilder::Text(b)
49 }
50}
51
52impl From<crate::VisionModelBuilder> for AnyModelBuilder {
53 fn from(b: crate::VisionModelBuilder) -> Self {
54 AnyModelBuilder::Vision(b)
55 }
56}
57
58impl From<crate::GgufModelBuilder> for AnyModelBuilder {
59 fn from(b: crate::GgufModelBuilder) -> Self {
60 AnyModelBuilder::Gguf(b)
61 }
62}
63
64impl From<crate::DiffusionModelBuilder> for AnyModelBuilder {
65 fn from(b: crate::DiffusionModelBuilder) -> Self {
66 AnyModelBuilder::Diffusion(b)
67 }
68}
69
70impl From<crate::SpeechModelBuilder> for AnyModelBuilder {
71 fn from(b: crate::SpeechModelBuilder) -> Self {
72 AnyModelBuilder::Speech(b)
73 }
74}
75
76impl From<crate::EmbeddingModelBuilder> for AnyModelBuilder {
77 fn from(b: crate::EmbeddingModelBuilder) -> Self {
78 AnyModelBuilder::Embedding(b)
79 }
80}
81
/// Builder that loads several models and serves them all from a single [`Model`].
pub struct MultiModelBuilder {
    /// Models to load, in insertion order. The first one is used to create the
    /// `MistralRs` instance; the rest are added to it afterwards.
    builders: Vec<AnyModelBuilder>,
    /// Model ID to make the default once all models are loaded. When `None`, no explicit
    /// default is set (whatever the engine picked is kept — presumably the first model).
    default_model_id: Option<String>,
}
87
88impl Default for MultiModelBuilder {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94impl MultiModelBuilder {
95 pub fn new() -> Self {
97 Self {
98 builders: Vec::new(),
99 default_model_id: None,
100 }
101 }
102
103 pub fn add_model<B: Into<AnyModelBuilder>>(mut self, builder: B) -> Self {
105 self.builders.push(builder.into());
106 self
107 }
108
109 pub fn with_default_model(mut self, model_id: impl ToString) -> Self {
111 self.default_model_id = Some(model_id.to_string());
112 self
113 }
114
115 pub async fn build(self) -> anyhow::Result<Model> {
117 if self.builders.is_empty() {
118 anyhow::bail!("MultiModelBuilder requires at least one model to be added");
119 }
120
121 let mut builders_iter = self.builders.into_iter();
123 let first_builder = builders_iter.next().unwrap();
124
125 let (pipeline, scheduler_config, add_model_config) = first_builder.build_pipeline().await?;
126
127 let mut runner_builder = mistralrs_core::MistralRsBuilder::new(
130 pipeline,
131 scheduler_config,
132 add_model_config.engine_config.throughput_logging_enabled,
133 add_model_config.engine_config.search_embedding_model,
134 );
135
136 if let Some(cb) = add_model_config.engine_config.search_callback.clone() {
137 runner_builder = runner_builder.with_search_callback(cb);
138 }
139
140 for (name, cb) in &add_model_config.engine_config.tool_callbacks {
141 runner_builder = runner_builder.with_tool_callback(name.clone(), cb.clone());
142 }
143
144 for (name, callback_with_tool) in &add_model_config.engine_config.tool_callbacks_with_tools
145 {
146 runner_builder = runner_builder.with_tool_callback_and_tool(
147 name.clone(),
148 callback_with_tool.callback.clone(),
149 callback_with_tool.tool.clone(),
150 );
151 }
152
153 if let Some(mcp_config) = add_model_config.mcp_client_config.clone() {
154 runner_builder = runner_builder.with_mcp_client(mcp_config);
155 }
156
157 if let Some(loader_config) = add_model_config.loader_config.clone() {
158 runner_builder = runner_builder.with_loader_config(loader_config);
159 }
160
161 runner_builder = runner_builder
162 .with_no_kv_cache(add_model_config.engine_config.no_kv_cache)
163 .with_no_prefix_cache(add_model_config.engine_config.no_prefix_cache)
164 .with_prefix_cache_n(add_model_config.engine_config.prefix_cache_n);
165
166 let mistralrs = runner_builder.build().await;
167
168 for builder in builders_iter {
170 let model_id = builder.model_id();
171 let (pipeline, scheduler_config, add_model_config) = builder.build_pipeline().await?;
172 mistralrs
173 .add_model(model_id, pipeline, scheduler_config, add_model_config)
174 .await
175 .map_err(|e| anyhow::anyhow!(e))?;
176 }
177
178 if let Some(default_id) = self.default_model_id {
180 mistralrs
181 .set_default_model_id(&default_id)
182 .map_err(|e| anyhow::anyhow!(e))?;
183 }
184 Ok(Model::new(mistralrs))
187 }
188}
189
190pub async fn build_model_from_pipeline(
196 pipeline: Arc<Mutex<dyn mistralrs_core::Pipeline>>,
197 scheduler_config: SchedulerConfig,
198 add_model_config: AddModelConfig,
199) -> Model {
200 let mut runner_builder = mistralrs_core::MistralRsBuilder::new(
201 pipeline,
202 scheduler_config,
203 add_model_config.engine_config.throughput_logging_enabled,
204 add_model_config.engine_config.search_embedding_model,
205 );
206
207 if let Some(cb) = add_model_config.engine_config.search_callback.clone() {
208 runner_builder = runner_builder.with_search_callback(cb);
209 }
210
211 for (name, cb) in &add_model_config.engine_config.tool_callbacks {
212 runner_builder = runner_builder.with_tool_callback(name.clone(), cb.clone());
213 }
214
215 for (name, callback_with_tool) in &add_model_config.engine_config.tool_callbacks_with_tools {
216 runner_builder = runner_builder.with_tool_callback_and_tool(
217 name.clone(),
218 callback_with_tool.callback.clone(),
219 callback_with_tool.tool.clone(),
220 );
221 }
222
223 if let Some(mcp_config) = add_model_config.mcp_client_config.clone() {
224 runner_builder = runner_builder.with_mcp_client(mcp_config);
225 }
226
227 if let Some(loader_config) = add_model_config.loader_config.clone() {
228 runner_builder = runner_builder.with_loader_config(loader_config);
229 }
230
231 runner_builder = runner_builder
232 .with_no_kv_cache(add_model_config.engine_config.no_kv_cache)
233 .with_no_prefix_cache(add_model_config.engine_config.no_prefix_cache)
234 .with_prefix_cache_n(add_model_config.engine_config.prefix_cache_n);
235
236 Model::new(runner_builder.build().await)
237}
238
239pub async fn build_text_pipeline(
242 builder: crate::TextModelBuilder,
243) -> anyhow::Result<(Arc<Mutex<dyn Pipeline>>, SchedulerConfig, AddModelConfig)> {
244 use crate::best_device;
245 use mistralrs_core::*;
246
247 let config = NormalSpecificConfig {
248 topology: builder.topology.clone(),
249 organization: builder.organization,
250 write_uqff: builder.write_uqff.clone(),
251 from_uqff: builder.from_uqff.clone(),
252 imatrix: builder.imatrix.clone(),
253 calibration_file: builder.calibration_file.clone(),
254 hf_cache_path: builder.hf_cache_path.clone(),
255 matformer_config_path: builder.matformer_config_path.clone(),
256 matformer_slice_name: builder.matformer_slice_name.clone(),
257 };
258
259 if builder.with_logging {
260 initialize_logging();
261 }
262
263 let loader = NormalLoaderBuilder::new(
264 config,
265 builder.chat_template.clone(),
266 builder.tokenizer_json.clone(),
267 Some(builder.model_id.clone()),
268 builder.no_kv_cache,
269 builder.jinja_explicit.clone(),
270 )
271 .build(builder.loader_type.clone())?;
272
273 let pipeline = loader.load_model_from_hf(
274 builder.hf_revision.clone(),
275 builder.token_source.clone(),
276 &builder.dtype,
277 &builder
278 .device
279 .clone()
280 .unwrap_or(best_device(builder.force_cpu).unwrap()),
281 !builder.with_logging,
282 builder
283 .device_mapping
284 .clone()
285 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
286 builder.isq,
287 builder.paged_attn_cfg,
288 )?;
289
290 let scheduler_config = match builder.paged_attn_cfg {
291 Some(_) => {
292 let config = pipeline
293 .lock()
294 .await
295 .get_metadata()
296 .cache_config
297 .as_ref()
298 .cloned();
299
300 if let Some(config) = config {
301 SchedulerConfig::PagedAttentionMeta {
302 max_num_seqs: builder.max_num_seqs,
303 config,
304 }
305 } else {
306 SchedulerConfig::DefaultScheduler {
307 method: DefaultSchedulerMethod::Fixed(builder.max_num_seqs.try_into()?),
308 }
309 }
310 }
311 None => SchedulerConfig::DefaultScheduler {
312 method: DefaultSchedulerMethod::Fixed(builder.max_num_seqs.try_into()?),
313 },
314 };
315
316 let engine_config = EngineConfig {
317 throughput_logging_enabled: builder.throughput_logging,
318 search_embedding_model: builder.search_embedding_model,
319 search_callback: builder.search_callback.clone(),
320 tool_callbacks: builder.tool_callbacks.clone(),
321 tool_callbacks_with_tools: builder
322 .tool_callbacks_with_tools
323 .iter()
324 .map(|(k, v)| {
325 (
326 k.clone(),
327 mistralrs_core::ToolCallbackWithTool {
328 callback: v.callback.clone(),
329 tool: v.tool.clone(),
330 },
331 )
332 })
333 .collect(),
334 no_kv_cache: builder.no_kv_cache,
335 no_prefix_cache: builder.prefix_cache_n.is_none(),
336 prefix_cache_n: builder.prefix_cache_n.unwrap_or(16),
337 disable_eos_stop: false,
338 };
339
340 let device = builder
342 .device
343 .clone()
344 .unwrap_or(best_device(builder.force_cpu).unwrap());
345 let device_map_setting = builder
346 .device_mapping
347 .clone()
348 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()));
349
350 let from_uqff_str = builder.from_uqff.as_ref().map(|paths| {
352 paths
353 .iter()
354 .map(|p| p.to_string_lossy())
355 .collect::<Vec<_>>()
356 .join(";")
357 });
358
359 let loader_config = ModelLoaderConfig {
360 model_selected: ModelSelected::Plain {
361 model_id: builder.model_id.clone(),
362 tokenizer_json: builder.tokenizer_json.clone(),
363 arch: builder.loader_type,
364 dtype: builder.dtype,
365 topology: builder.topology_path.clone(),
366 organization: Some(builder.organization),
367 write_uqff: builder.write_uqff.clone(),
368 from_uqff: from_uqff_str,
369 imatrix: builder.imatrix.clone(),
370 calibration_file: builder.calibration_file.clone(),
371 max_seq_len: AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN,
372 max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
373 hf_cache_path: builder.hf_cache_path.clone(),
374 matformer_config_path: builder.matformer_config_path.clone(),
375 matformer_slice_name: builder.matformer_slice_name.clone(),
376 },
377 token_source: builder.token_source.clone(),
378 hf_revision: builder.hf_revision.clone(),
379 dtype: builder.dtype,
380 device,
381 device_map_setting,
382 isq: builder.isq,
383 paged_attn_config: builder.paged_attn_cfg,
384 silent: !builder.with_logging,
385 chat_template: builder.chat_template.clone(),
386 jinja_explicit: builder.jinja_explicit.clone(),
387 };
388
389 let add_model_config = AddModelConfig {
390 engine_config,
391 mcp_client_config: builder.mcp_client_config.clone(),
392 loader_config: Some(loader_config),
393 };
394
395 Ok((pipeline, scheduler_config, add_model_config))
396}
397
398pub async fn build_vision_pipeline(
401 builder: crate::VisionModelBuilder,
402) -> anyhow::Result<(Arc<Mutex<dyn Pipeline>>, SchedulerConfig, AddModelConfig)> {
403 use crate::best_device;
404 use mistralrs_core::*;
405
406 let config = VisionSpecificConfig {
407 topology: builder.topology.clone(),
408 write_uqff: builder.write_uqff.clone(),
409 from_uqff: builder.from_uqff.clone(),
410 max_edge: builder.max_edge,
411 calibration_file: builder.calibration_file.clone(),
412 imatrix: builder.imatrix.clone(),
413 hf_cache_path: builder.hf_cache_path.clone(),
414 matformer_config_path: builder.matformer_config_path.clone(),
415 matformer_slice_name: builder.matformer_slice_name.clone(),
416 };
417
418 if builder.with_logging {
419 initialize_logging();
420 }
421
422 let loader = VisionLoaderBuilder::new(
423 config,
424 builder.chat_template.clone(),
425 builder.tokenizer_json.clone(),
426 Some(builder.model_id.clone()),
427 builder.jinja_explicit.clone(),
428 )
429 .build(builder.loader_type.clone());
430
431 let pipeline = loader.load_model_from_hf(
432 builder.hf_revision.clone(),
433 builder.token_source.clone(),
434 &builder.dtype,
435 &builder
436 .device
437 .clone()
438 .unwrap_or(best_device(builder.force_cpu).unwrap()),
439 !builder.with_logging,
440 builder
441 .device_mapping
442 .clone()
443 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_vision())),
444 builder.isq,
445 builder.paged_attn_cfg,
446 )?;
447
448 let scheduler_config = match builder.paged_attn_cfg {
449 Some(_) => {
450 let config = pipeline
451 .lock()
452 .await
453 .get_metadata()
454 .cache_config
455 .as_ref()
456 .cloned();
457
458 if let Some(config) = config {
459 SchedulerConfig::PagedAttentionMeta {
460 max_num_seqs: builder.max_num_seqs,
461 config,
462 }
463 } else {
464 SchedulerConfig::DefaultScheduler {
465 method: DefaultSchedulerMethod::Fixed(builder.max_num_seqs.try_into()?),
466 }
467 }
468 }
469 None => SchedulerConfig::DefaultScheduler {
470 method: DefaultSchedulerMethod::Fixed(builder.max_num_seqs.try_into()?),
471 },
472 };
473
474 let engine_config = EngineConfig {
475 throughput_logging_enabled: builder.throughput_logging,
476 search_embedding_model: builder.search_embedding_model,
477 search_callback: builder.search_callback.clone(),
478 tool_callbacks: builder.tool_callbacks.clone(),
479 tool_callbacks_with_tools: builder
480 .tool_callbacks_with_tools
481 .iter()
482 .map(|(k, v)| {
483 (
484 k.clone(),
485 mistralrs_core::ToolCallbackWithTool {
486 callback: v.callback.clone(),
487 tool: v.tool.clone(),
488 },
489 )
490 })
491 .collect(),
492 no_kv_cache: false,
493 no_prefix_cache: builder.prefix_cache_n.is_none(),
494 prefix_cache_n: builder.prefix_cache_n.unwrap_or(16),
495 disable_eos_stop: false,
496 };
497
498 let device = builder
500 .device
501 .clone()
502 .unwrap_or(best_device(builder.force_cpu).unwrap());
503 let device_map_setting = builder
504 .device_mapping
505 .clone()
506 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_vision()));
507
508 let from_uqff_str = builder.from_uqff.as_ref().map(|paths| {
510 paths
511 .iter()
512 .map(|p| p.to_string_lossy())
513 .collect::<Vec<_>>()
514 .join(";")
515 });
516
517 let loader_config = ModelLoaderConfig {
518 model_selected: ModelSelected::VisionPlain {
519 model_id: builder.model_id.clone(),
520 tokenizer_json: builder.tokenizer_json.clone(),
521 arch: builder.loader_type,
522 dtype: builder.dtype,
523 topology: builder.topology_path.clone(),
524 write_uqff: builder.write_uqff.clone(),
525 from_uqff: from_uqff_str,
526 max_edge: builder.max_edge,
527 calibration_file: builder.calibration_file.clone(),
528 imatrix: builder.imatrix.clone(),
529 max_seq_len: AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN,
530 max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
531 max_num_images: AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES,
532 max_image_length: AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH,
533 hf_cache_path: builder.hf_cache_path.clone(),
534 matformer_config_path: builder.matformer_config_path.clone(),
535 matformer_slice_name: builder.matformer_slice_name.clone(),
536 },
537 token_source: builder.token_source.clone(),
538 hf_revision: builder.hf_revision.clone(),
539 dtype: builder.dtype,
540 device,
541 device_map_setting,
542 isq: builder.isq,
543 paged_attn_config: builder.paged_attn_cfg,
544 silent: !builder.with_logging,
545 chat_template: builder.chat_template.clone(),
546 jinja_explicit: builder.jinja_explicit.clone(),
547 };
548
549 let add_model_config = AddModelConfig {
550 engine_config,
551 mcp_client_config: None,
552 loader_config: Some(loader_config),
553 };
554
555 Ok((pipeline, scheduler_config, add_model_config))
556}
557
558pub async fn build_gguf_pipeline(
561 builder: crate::GgufModelBuilder,
562) -> anyhow::Result<(Arc<Mutex<dyn Pipeline>>, SchedulerConfig, AddModelConfig)> {
563 use crate::best_device;
564 use mistralrs_core::*;
565
566 let config = GGUFSpecificConfig {
567 topology: builder.topology.clone(),
568 };
569
570 if builder.with_logging {
571 initialize_logging();
572 }
573
574 let loader = GGUFLoaderBuilder::new(
575 builder.chat_template.clone(),
576 builder.tok_model_id.clone(),
577 builder.model_id.clone(),
578 builder.files.clone(),
579 config,
580 builder.no_kv_cache,
581 builder.jinja_explicit.clone(),
582 )
583 .build();
584
585 let pipeline = loader.load_model_from_hf(
586 builder.hf_revision.clone(),
587 builder.token_source.clone(),
588 &ModelDType::Auto,
589 &builder
590 .device
591 .clone()
592 .unwrap_or(best_device(builder.force_cpu).unwrap()),
593 !builder.with_logging,
594 builder
595 .device_mapping
596 .clone()
597 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
598 None,
599 builder.paged_attn_cfg,
600 )?;
601
602 let scheduler_config = match builder.paged_attn_cfg {
603 Some(_) => {
604 let config = pipeline
605 .lock()
606 .await
607 .get_metadata()
608 .cache_config
609 .as_ref()
610 .unwrap()
611 .clone();
612
613 SchedulerConfig::PagedAttentionMeta {
614 max_num_seqs: builder.max_num_seqs,
615 config,
616 }
617 }
618 None => SchedulerConfig::DefaultScheduler {
619 method: DefaultSchedulerMethod::Fixed(builder.max_num_seqs.try_into()?),
620 },
621 };
622
623 let engine_config = EngineConfig {
624 throughput_logging_enabled: builder.throughput_logging,
625 search_embedding_model: builder.search_embedding_model,
626 search_callback: builder.search_callback.clone(),
627 tool_callbacks: builder.tool_callbacks.clone(),
628 tool_callbacks_with_tools: builder
629 .tool_callbacks_with_tools
630 .iter()
631 .map(|(k, v)| {
632 (
633 k.clone(),
634 mistralrs_core::ToolCallbackWithTool {
635 callback: v.callback.clone(),
636 tool: v.tool.clone(),
637 },
638 )
639 })
640 .collect(),
641 no_kv_cache: builder.no_kv_cache,
642 no_prefix_cache: builder.prefix_cache_n.is_none(),
643 prefix_cache_n: builder.prefix_cache_n.unwrap_or(16),
644 disable_eos_stop: false,
645 };
646
647 let device = builder
649 .device
650 .clone()
651 .unwrap_or(best_device(builder.force_cpu).unwrap());
652 let device_map_setting = builder
653 .device_mapping
654 .clone()
655 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()));
656
657 let loader_config = ModelLoaderConfig {
658 model_selected: ModelSelected::GGUF {
659 tok_model_id: builder.tok_model_id.clone(),
660 quantized_model_id: builder.model_id.clone(),
661 quantized_filename: builder.files.join(GGUF_MULTI_FILE_DELIMITER),
662 dtype: ModelDType::Auto,
663 topology: builder.topology_path.clone(),
664 max_seq_len: AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN,
665 max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
666 },
667 token_source: builder.token_source.clone(),
668 hf_revision: builder.hf_revision.clone(),
669 dtype: ModelDType::Auto,
670 device,
671 device_map_setting,
672 isq: None,
673 paged_attn_config: builder.paged_attn_cfg,
674 silent: !builder.with_logging,
675 chat_template: builder.chat_template.clone(),
676 jinja_explicit: builder.jinja_explicit.clone(),
677 };
678
679 let add_model_config = AddModelConfig {
680 engine_config,
681 mcp_client_config: None,
682 loader_config: Some(loader_config),
683 };
684
685 Ok((pipeline, scheduler_config, add_model_config))
686}
687
688pub async fn build_diffusion_pipeline(
691 builder: crate::DiffusionModelBuilder,
692) -> anyhow::Result<(Arc<Mutex<dyn Pipeline>>, SchedulerConfig, AddModelConfig)> {
693 use crate::best_device;
694 use mistralrs_core::*;
695
696 if builder.with_logging {
697 initialize_logging();
698 }
699
700 let loader = DiffusionLoaderBuilder::new(Some(builder.model_id.clone()))
701 .build(builder.loader_type.clone());
702
703 let pipeline = loader.load_model_from_hf(
704 builder.hf_revision.clone(),
705 builder.token_source.clone(),
706 &builder.dtype,
707 &best_device(builder.force_cpu)?,
708 !builder.with_logging,
709 DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
710 None,
711 None,
712 )?;
713
714 let scheduler_config = SchedulerConfig::DefaultScheduler {
715 method: DefaultSchedulerMethod::Fixed(builder.max_num_seqs.try_into()?),
716 };
717
718 let engine_config = EngineConfig::default();
719
720 let device = best_device(builder.force_cpu)?;
722
723 let loader_config = ModelLoaderConfig {
724 model_selected: ModelSelected::DiffusionPlain {
725 model_id: builder.model_id.clone(),
726 arch: builder.loader_type,
727 dtype: builder.dtype,
728 },
729 token_source: builder.token_source.clone(),
730 hf_revision: builder.hf_revision.clone(),
731 dtype: builder.dtype,
732 device,
733 device_map_setting: DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
734 isq: None,
735 paged_attn_config: None,
736 silent: !builder.with_logging,
737 chat_template: None,
738 jinja_explicit: None,
739 };
740
741 let add_model_config = AddModelConfig {
742 engine_config,
743 mcp_client_config: None,
744 loader_config: Some(loader_config),
745 };
746
747 Ok((pipeline, scheduler_config, add_model_config))
748}
749
750pub async fn build_speech_pipeline(
753 builder: crate::SpeechModelBuilder,
754) -> anyhow::Result<(Arc<Mutex<dyn Pipeline>>, SchedulerConfig, AddModelConfig)> {
755 use crate::best_device;
756 use mistralrs_core::*;
757
758 if builder.with_logging {
759 initialize_logging();
760 }
761
762 let loader = SpeechLoader {
763 model_id: builder.model_id.clone(),
764 dac_model_id: builder.dac_model_id.clone(),
765 arch: builder.loader_type,
766 cfg: builder.cfg,
767 };
768
769 let pipeline = loader.load_model_from_hf(
770 builder.hf_revision.clone(),
771 builder.token_source.clone(),
772 &builder.dtype,
773 &best_device(builder.force_cpu)?,
774 !builder.with_logging,
775 DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
776 None,
777 None,
778 )?;
779
780 let scheduler_config = SchedulerConfig::DefaultScheduler {
781 method: DefaultSchedulerMethod::Fixed(builder.max_num_seqs.try_into()?),
782 };
783
784 let engine_config = EngineConfig::default();
785
786 let device = best_device(builder.force_cpu)?;
788
789 let loader_config = ModelLoaderConfig {
790 model_selected: ModelSelected::Speech {
791 model_id: builder.model_id.clone(),
792 dac_model_id: builder.dac_model_id.clone(),
793 arch: builder.loader_type,
794 dtype: builder.dtype,
795 },
796 token_source: builder.token_source.clone(),
797 hf_revision: builder.hf_revision.clone(),
798 dtype: builder.dtype,
799 device,
800 device_map_setting: DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
801 isq: None,
802 paged_attn_config: None,
803 silent: !builder.with_logging,
804 chat_template: None,
805 jinja_explicit: None,
806 };
807
808 let add_model_config = AddModelConfig {
809 engine_config,
810 mcp_client_config: None,
811 loader_config: Some(loader_config),
812 };
813
814 Ok((pipeline, scheduler_config, add_model_config))
815}
816
817pub async fn build_embedding_pipeline(
820 builder: crate::EmbeddingModelBuilder,
821) -> anyhow::Result<(Arc<Mutex<dyn Pipeline>>, SchedulerConfig, AddModelConfig)> {
822 use crate::best_device;
823 use mistralrs_core::*;
824
825 let config = EmbeddingSpecificConfig {
826 topology: builder.topology.clone(),
827 write_uqff: builder.write_uqff.clone(),
828 from_uqff: builder.from_uqff.clone(),
829 hf_cache_path: builder.hf_cache_path.clone(),
830 };
831
832 if builder.with_logging {
833 initialize_logging();
834 }
835
836 let loader = EmbeddingLoaderBuilder::new(
837 config,
838 builder.tokenizer_json.clone(),
839 Some(builder.model_id.clone()),
840 )
841 .build(builder.loader_type.clone());
842
843 let pipeline = loader.load_model_from_hf(
844 builder.hf_revision.clone(),
845 builder.token_source.clone(),
846 &builder.dtype,
847 &builder
848 .device
849 .clone()
850 .unwrap_or(best_device(builder.force_cpu).unwrap()),
851 !builder.with_logging,
852 builder
853 .device_mapping
854 .clone()
855 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
856 builder.isq,
857 None,
858 )?;
859
860 let scheduler_config = SchedulerConfig::DefaultScheduler {
861 method: DefaultSchedulerMethod::Fixed(builder.max_num_seqs.try_into()?),
862 };
863
864 let engine_config = EngineConfig {
865 throughput_logging_enabled: builder.throughput_logging,
866 ..Default::default()
867 };
868
869 let device = builder
871 .device
872 .clone()
873 .unwrap_or(best_device(builder.force_cpu).unwrap());
874 let device_map_setting = builder
875 .device_mapping
876 .clone()
877 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()));
878
879 let from_uqff_str = builder.from_uqff.as_ref().map(|paths| {
881 paths
882 .iter()
883 .map(|p| p.to_string_lossy())
884 .collect::<Vec<_>>()
885 .join(";")
886 });
887
888 let loader_config = ModelLoaderConfig {
889 model_selected: ModelSelected::Embedding {
890 model_id: builder.model_id.clone(),
891 tokenizer_json: builder.tokenizer_json.clone(),
892 arch: builder.loader_type,
893 dtype: builder.dtype,
894 topology: builder.topology_path.clone(),
895 write_uqff: builder.write_uqff.clone(),
896 from_uqff: from_uqff_str,
897 hf_cache_path: builder.hf_cache_path.clone(),
898 },
899 token_source: builder.token_source.clone(),
900 hf_revision: builder.hf_revision.clone(),
901 dtype: builder.dtype,
902 device,
903 device_map_setting,
904 isq: builder.isq,
905 paged_attn_config: None,
906 silent: !builder.with_logging,
907 chat_template: None,
908 jinja_explicit: None,
909 };
910
911 let add_model_config = AddModelConfig {
912 engine_config,
913 mcp_client_config: None,
914 loader_config: Some(loader_config),
915 };
916
917 Ok((pipeline, scheduler_config, add_model_config))
918}