1use super::isq::UqffFullSer;
2use super::{
3 get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, CacheManagerMixin,
4 EitherCache, ForwardInputsResult, GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin,
5 ModelCategory, ModelKind, ModelPaths, PreProcessingMixin, TokenSource,
6};
7use crate::attention::ATTENTION_CHUNK_SIZE;
8use crate::device_map::{self, DeviceMapper};
9use crate::distributed::{self, WorkerTransferData};
10use crate::embedding_models::inputs_processor::{EmbeddingProcessor, ModelInputs};
11use crate::embedding_models::{Dense, DenseActivation, Normalize, Pooling};
12use crate::embedding_normal_model_loader;
13use crate::embedding_normal_model_loader_sharded;
14use crate::get_embedding_paths;
15use crate::paged_attention::AttentionImplementation;
16use crate::pipeline::loaders::auto_device_map;
17use crate::pipeline::loaders::QuantizationConfigShim;
18use crate::pipeline::sampling::sample_and_add_toks;
19use crate::pipeline::EmbeddingLoaderType;
20use crate::pipeline::EmbeddingModel;
21use crate::pipeline::EmbeddingModelLoader;
22use crate::pipeline::{AutoEmbeddingLoader, EmbeddingModulePaths};
23use crate::pipeline::{ChatTemplate, EmbeddingModelPaths, IsqOrganization, Processor};
24use crate::pipeline::{EmbeddingGemmaLoader, Qwen3EmbeddingLoader};
25use crate::prefix_cacher::PrefixCacheManagerV2;
26use crate::sequence::Sequence;
27use crate::utils::tokenizer::get_tokenizer;
28use crate::utils::{
29 progress::{new_multi_progress, ProgressScopeGuard},
30 tokens::get_token,
31 varbuilder_utils::from_mmaped_safetensors,
32};
33use crate::Modalities;
34use crate::SupportedModality;
35use crate::{
36 api_get_file, get_uqff_paths, DeviceMapSetting, PagedAttentionConfig, Pipeline, Topology,
37 TryIntoDType, GLOBAL_HF_CACHE,
38};
39use anyhow::Context;
40use anyhow::Result;
41use candle_core::{Device, Tensor};
42use candle_nn::{Linear, Module};
43use hf_hub::Cache;
44use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
45use mistralrs_quant::log::once_log_info;
46use mistralrs_quant::safetensors::MmapedSafetensors;
47use mistralrs_quant::{
48 AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
49};
50use rand_isaac::Isaac64Rng;
51use std::any::Any;
52use std::borrow::Cow;
53use std::env;
54use std::path::{Path, PathBuf};
55use std::str::FromStr;
56use std::sync::{Arc, RwLock};
57use tokenizers::Tokenizer;
58use tokio::sync::Mutex;
59use tracing::{info, warn};
60
/// A loaded embedding model plus the post-processing module stack (pooling, dense,
/// normalize) described by the model's `modules.json`.
pub struct EmbeddingPipeline {
    /// Transformer backbone; its output is fed through `modules` in `forward_inputs`.
    model: Box<dyn EmbeddingModel + Send + Sync>,
    /// Tokenizer loaded alongside the model.
    tokenizer: Arc<Tokenizer>,
    /// Id the model was loaded from; also reported as the pipeline name.
    model_id: String,
    /// Metadata shared with the engine.
    metadata: Arc<GeneralMetadata>,
    /// Optional per-layer topology, reused when re-quantizing.
    topology: Option<Topology>,
    /// Whether progress output is suppressed; forwarded when re-quantizing.
    silent: bool,
    /// Raw text of the model's config file.
    config: String,
    /// Serialized form of `modules_manifest`, passed to UQFF serialization.
    modules_ser: String,
    /// Module entries from `modules.json`, in declaration order.
    modules_manifest: Vec<EmbeddingModulePaths>,
    /// Device mapper exposed through `MetadataMixin::device_mapper`.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// Post-transformer modules, applied in order to the model output.
    modules: Vec<Box<dyn Module + Send + Sync>>,
    /// Input processor for embedding requests.
    processor: Arc<dyn Processor + Send + Sync>,
}
75
/// [`Loader`] for embedding models, produced by `EmbeddingLoaderBuilder::build`.
pub struct EmbeddingLoader {
    /// Architecture-specific loader chosen by the builder.
    inner: Box<dyn EmbeddingModelLoader>,
    /// Model id, returned by `get_id`.
    model_id: String,
    /// Embedding-specific options (topology, UQFF input/output).
    config: EmbeddingSpecificConfig,
    /// Requested model kind; the loading code in this file only constructs `ModelKind::Normal`.
    kind: ModelKind,
    /// Optional tokenizer override. NOTE(review): not read directly here; presumably consumed
    /// by the path-resolution macros — confirm.
    tokenizer_json: Option<String>,
    /// Token source recorded by `load_model_from_hf`.
    token_source: RwLock<Option<TokenSource>>,
    /// Revision recorded by `load_model_from_hf`.
    revision: RwLock<Option<String>>,
    /// Resolved UQFF artifact paths, filled in by `load_model_from_hf` when
    /// `config.from_uqff` is set.
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    /// Directory used to construct the Hugging Face cache.
    hf_cache_path: Option<PathBuf>,
    /// LoRA adapter ids set via `EmbeddingLoaderBuilder::with_lora`; not read directly by the
    /// code here.
    lora_adapter_ids: Option<Vec<String>>,
}
89
/// Builder for [`EmbeddingLoader`]; see `EmbeddingLoaderBuilder::new` and `build`.
#[derive(Default)]
pub struct EmbeddingLoaderBuilder {
    /// Model id; must be set before `build` is called.
    model_id: Option<String>,
    /// Embedding-specific options forwarded to the loader.
    config: EmbeddingSpecificConfig,
    /// Model kind; `new` sets `ModelKind::Normal`, `with_lora` switches it to a LoRA adapter.
    kind: ModelKind,
    /// Optional tokenizer override forwarded to the loader.
    tokenizer_json: Option<String>,
    /// Directory used to construct the Hugging Face cache.
    hf_cache_path: Option<PathBuf>,
    /// LoRA adapter ids set by `with_lora`.
    lora_adapter_ids: Option<Vec<String>>,
}
100
/// Options specific to loading embedding models.
#[derive(Clone, Default)]
pub struct EmbeddingSpecificConfig {
    /// Optional topology providing per-layer device and ISQ overrides.
    pub topology: Option<Topology>,
    /// When set (and `from_uqff` is not), the loaded model is written out as UQFF to this path.
    pub write_uqff: Option<PathBuf>,
    /// Pre-quantized UQFF artifacts to load into the model after construction.
    pub from_uqff: Option<Vec<PathBuf>>,
    /// NOTE(review): not read by the loading code visible here; the loader uses the builder's
    /// `hf_cache_path` instead — confirm whether this field is consumed elsewhere.
    pub hf_cache_path: Option<PathBuf>,
}
109
110impl EmbeddingLoaderBuilder {
111 pub fn new(
112 config: EmbeddingSpecificConfig,
113 tokenizer_json: Option<String>,
114 model_id: Option<String>,
115 ) -> Self {
116 Self {
117 config,
118 tokenizer_json,
119 model_id,
120 kind: ModelKind::Normal,
121 hf_cache_path: None,
122 ..Default::default()
123 }
124 }
125
126 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
127 self.hf_cache_path = Some(hf_cache_path);
128 self
129 }
130
131 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
132 self.kind = ModelKind::Adapter {
133 adapter: AdapterKind::Lora,
134 };
135 self.lora_adapter_ids = Some(lora_adapter_ids);
136 self
137 }
138
139 pub fn build(self, loader: Option<EmbeddingLoaderType>) -> Box<dyn Loader> {
140 let loader: Box<dyn EmbeddingModelLoader> = match loader {
141 Some(EmbeddingLoaderType::EmbeddingGemma) => Box::new(EmbeddingGemmaLoader),
142 Some(EmbeddingLoaderType::Qwen3Embedding) => Box::new(Qwen3EmbeddingLoader),
143 None => Box::new(AutoEmbeddingLoader),
144 };
145 Box::new(EmbeddingLoader {
146 inner: loader,
147 model_id: self.model_id.unwrap(),
148 config: self.config,
149 kind: self.kind,
150 tokenizer_json: self.tokenizer_json,
151 token_source: RwLock::new(None),
152 revision: RwLock::new(None),
153 from_uqff: RwLock::new(None),
154 hf_cache_path: self.hf_cache_path,
155 lora_adapter_ids: self.lora_adapter_ids,
156 })
157 }
158}
159
160impl Loader for EmbeddingLoader {
161 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
162 fn load_model_from_hf(
163 &self,
164 revision: Option<String>,
165 token_source: TokenSource,
166 dtype: &dyn TryIntoDType,
167 device: &Device,
168 silent: bool,
169 mapper: DeviceMapSetting,
170 in_situ_quant: Option<IsqType>,
171 paged_attn_config: Option<PagedAttentionConfig>,
172 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
173 let _progress_guard = ProgressScopeGuard::new(silent);
174 let cache = self
175 .hf_cache_path
176 .clone()
177 .map(Cache::new)
178 .unwrap_or_default();
179 GLOBAL_HF_CACHE.get_or_init(|| cache);
180
181 let paths: anyhow::Result<Box<dyn ModelPaths>> = get_embedding_paths!(
182 EmbeddingModelPaths,
183 &token_source,
184 revision.clone(),
185 self,
186 None,
187 None,
188 silent,
189 self.config.from_uqff.is_some()
190 );
191 if let Some(from_uqff) = self.config.from_uqff.clone() {
192 *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
193 }
194 *self
195 .token_source
196 .write()
197 .expect("Failed to write to token source") = Some(token_source);
198 *self.revision.write().expect("Failed to write to revision") = revision;
199 self.load_model_from_path(
200 &paths?,
201 dtype,
202 device,
203 silent,
204 mapper,
205 in_situ_quant,
206 paged_attn_config,
207 )
208 }
209
210 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
211 fn load_model_from_path(
212 &self,
213 paths: &Box<dyn ModelPaths>,
214 dtype: &dyn TryIntoDType,
215 device: &Device,
216 silent: bool,
217 mut mapper: DeviceMapSetting,
218 mut in_situ_quant: Option<IsqType>,
219 mut paged_attn_config: Option<PagedAttentionConfig>,
220 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
221 let _progress_guard = ProgressScopeGuard::new(silent);
222 let config = std::fs::read_to_string(paths.get_config_filename())?;
223
224 if paged_attn_config.is_some() {
225 warn!("PagedAttention is not supported for embedding models, disabling it.");
226 paged_attn_config = None;
227 }
228
229 info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
230
231 let use_nccl = mistralrs_quant::distributed::use_nccl();
232
233 let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
234 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
235 let WorkerTransferData::Init { id: _, worker_rank } = payload;
236 vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
237 } else if use_nccl {
238 vec![candle_core::Device::new_cuda_with_stream(0)?]
239 } else {
240 device_map::get_all_similar_devices(device)?
241 };
242 #[cfg(feature = "cuda")]
243 for device in &available_devices {
244 if let Device::Cuda(dev) = device {
245 unsafe { dev.disable_event_tracking() };
246 }
247 }
248 let device = if use_nccl {
249 available_devices[0].clone()
250 } else {
251 device.clone()
252 };
253
254 if use_nccl {
256 mapper = DeviceMapSetting::DummyNccl {
257 nm_device: available_devices[0].clone(),
258 };
259 } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
260 let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
262
263 if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
265 in_situ_quant = None;
266 }
267
268 let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
271 if let Some(serialized) = &*self.from_uqff.read().unwrap() {
272 let weight_pack_factor = {
273 let ser_artifacts = unsafe {
274 candle_core::safetensors::MmapedSafetensors::multi(serialized)?
275 };
276 let mut total_pack_factors = 0;
277 let total_tensors = ser_artifacts.tensors().len();
278 for (_, artifact) in ser_artifacts.tensors() {
279 let artifact = artifact.data();
280 let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
282 let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
283 {
284 QuantizedSerdeType::Hqq => {
285 HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
286 .pack_factor(dtype)
287 }
288 QuantizedSerdeType::Gguf => {
289 GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
290 .pack_factor(dtype)
291 }
292 QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
293 QuantizedSerdeType::Unquant => 1,
294 QuantizedSerdeType::Afq => {
295 AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
296 .pack_factor(dtype)
297 }
298 };
299 total_pack_factors += pack_factor;
300 }
301
302 total_pack_factors / total_tensors
303 };
304
305 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
306 &config,
307 dtype,
308 weight_pack_factor,
309 None,
310 )?;
311 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
312 &config,
313 dtype,
314 weight_pack_factor,
315 None,
316 )?;
317 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
318 (
319 layer_sizes_in_bytes,
320 non_mapped_size_in_bytes,
321 layer_sizes_sum + non_mapped_size_in_bytes,
322 )
323 } else if let Some(isq) = in_situ_quant {
324 let weight_pack_factor = isq.pack_factor(dtype);
325 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
326 &config,
327 dtype,
328 weight_pack_factor,
329 None,
330 )?;
331 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
332 &config,
333 dtype,
334 weight_pack_factor,
335 None,
336 )?;
337 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
338 (
339 layer_sizes_in_bytes,
340 non_mapped_size_in_bytes,
341 layer_sizes_sum + non_mapped_size_in_bytes,
342 )
343 } else {
344 let weight_pack_factor =
346 QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
347 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
348 &config,
349 dtype,
350 weight_pack_factor,
351 None,
352 )?;
353 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
354 &config,
355 dtype,
356 weight_pack_factor,
357 None,
358 )?;
359 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
360 (
361 layer_sizes_in_bytes,
362 non_mapped_size_in_bytes,
363 layer_sizes_sum + non_mapped_size_in_bytes,
364 )
365 };
366
367 let new = auto_device_map::get_device_layers(
368 &*self.inner,
369 &config,
370 self.inner.num_layers(&config)?,
371 layer_sizes_in_bytes,
372 non_mapped_size_in_bytes,
373 total_model_size_in_bytes,
374 &available_devices,
375 dtype,
376 ¶ms,
377 paged_attn_config.as_ref(),
378 )?;
379 mapper = DeviceMapSetting::Map(new);
380 }
381
382 let pipeline_mapper = mapper.into_mapper(
383 self.inner.num_layers(&config)?,
384 &device,
385 self.config.topology.as_ref(),
386 )?;
387 let mapper = mapper.into_mapper(
388 self.inner.num_layers(&config)?,
389 &device,
390 self.config.topology.as_ref(),
391 )?;
392 let mut layer_devices = Vec::new();
393 for layer in 0..self.inner.num_layers(&config)? {
394 let device = mapper.device_for(layer, false).cloned();
395 layer_devices.push(device);
396 }
397 let dtype = mapper.get_min_dtype(dtype)?;
398
399 info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
400 if crate::using_flash_attn() {
401 once_log_info("FlashAttention is enabled.");
402 }
403
404 let topology_overrides = self
405 .config
406 .topology
407 .as_ref()
408 .map(|topology| {
409 topology
410 .pattern_overrides()
411 .into_iter()
412 .map(|(regex, layer)| ImmediateIsqOverride {
413 predicate: regex,
414 ty: layer.isq,
415 device: layer.device.clone(),
416 })
417 .collect::<Vec<_>>()
418 })
419 .unwrap_or_default();
420 let has_override_isq = topology_overrides
421 .iter()
422 .any(|override_entry| override_entry.ty.is_some());
423 let topology_requires_post_quant = self
424 .config
425 .topology
426 .as_ref()
427 .is_some_and(|topology| topology.requires_post_quantization());
428
429 let allow_immediate_cli = !device.is_cuda() && in_situ_quant.is_some();
430
431 let mut immediate_ty = None;
432 let mut immediate_predicates = Vec::new();
433 if allow_immediate_cli {
434 immediate_ty = in_situ_quant;
435 immediate_predicates = self.inner.immediate_isq_predicates(&config)?;
436 info!("Applying ISQ to {in_situ_quant:?}");
437 if immediate_predicates.is_empty() {
438 warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
439 }
440 }
441
442 let use_immediate = allow_immediate_cli || has_override_isq;
443 if use_immediate {
444 mistralrs_quant::set_immediate_isq_with_overrides(
445 immediate_ty,
446 immediate_predicates.clone(),
447 topology_overrides.clone(),
448 );
449 }
450
451 let mut loading_isq = if use_immediate {
453 false
454 } else {
455 in_situ_quant.is_some()
456 };
457 loading_isq |= topology_requires_post_quant;
458
459 let load_device = if !loading_isq {
461 loading_isq = false;
462 device.clone()
463 } else {
464 Device::Cpu
465 };
466
467 let attention_mechanism = if paged_attn_config.is_some() {
468 AttentionImplementation::PagedAttention
469 } else {
470 AttentionImplementation::Eager
471 };
472
473 let multi_progress = Arc::new(new_multi_progress());
474
475 let modules_config: Vec<_> = paths
476 .get_modules()
477 .context("Embedding models require the `modules.json` file.")?
478 .to_vec();
479 assert!(matches!(
480 modules_config.first(),
481 Some(EmbeddingModulePaths::Transformer { .. })
482 ));
483
484 let mut modules: Vec<Box<dyn Module + Send + Sync>> = Vec::new();
485 for module in &modules_config {
486 match module {
487 EmbeddingModulePaths::Transformer { .. } => (),
488 EmbeddingModulePaths::Pooling { config, .. } => {
489 let layer: Pooling = serde_json::from_str(&std::fs::read_to_string(config)?)?;
490 modules.push(Box::new(layer));
491 }
492 EmbeddingModulePaths::Dense { config, model, .. } => {
493 let config: Dense = serde_json::from_str(&std::fs::read_to_string(config)?)?;
494 let safetensors = unsafe { MmapedSafetensors::new(model)? };
495 let weight = safetensors.load("linear.weight", &device, Some(dtype))?;
496 let bias = if config.bias {
497 Some(safetensors.load("linear.bias", &device, Some(dtype))?)
498 } else {
499 None
500 };
501 let (out_f, in_f) = weight.dims2()?;
502 assert_eq!((out_f, in_f), (config.out_features, config.in_features));
503 if !matches!(config.activation_function, DenseActivation::Identity) {
504 anyhow::bail!("Expected Identity activation function.");
505 }
506
507 modules.push(Box::new(Linear::new(weight, bias)));
508 }
509 EmbeddingModulePaths::Normalize { .. } => {
510 modules.push(Box::new(Normalize));
511 }
512 }
513 }
514 let modules_ser = EmbeddingModulePaths::serialize_modules(&modules_config);
515
516 let mut model = if use_nccl {
517 let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
518 dtype,
519 &device,
520 &available_devices,
521 silent,
522 &config,
523 loading_isq,
524 self.config.from_uqff.is_some(),
525 IsqOrganization::Default,
526 &*self.inner,
527 paths.as_ref(),
528 )?;
529
530 match self.kind {
532 ModelKind::Normal => embedding_normal_model_loader_sharded!(
533 sharded_vb,
534 config,
535 self.inner,
536 mapper,
537 loading_isq,
538 device.clone(),
539 attention_mechanism,
540 multi_progress.clone(),
541 ),
542 _ => unreachable!(),
543 }
544 } else {
545 match self.kind {
546 ModelKind::Normal => embedding_normal_model_loader!(
547 paths,
548 Some(dtype),
549 &load_device,
550 layer_devices.clone(),
551 config,
552 self.inner,
553 silent,
554 mapper,
555 loading_isq,
556 self.config.from_uqff.is_some(),
557 device.clone(),
558 attention_mechanism,
559 multi_progress,
560 ),
561 _ => unreachable!(),
562 }
563 };
564
565 let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
566
567 let should_serialize = self.config.write_uqff.is_some();
568 let should_quantize_pass = loading_isq;
569
570 if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
571 if should_quantize_pass {
572 info!("Applying ISQ to all ranks.");
573 } else {
574 info!("Serializing existing ISQ tensors without additional quantization.");
575 }
576 model.quantize(
577 in_situ_quant,
578 device.clone(),
579 self.config.topology.as_ref(),
580 silent,
581 None,
582 IsqOrganization::Default,
583 should_quantize_pass,
584 self.config.write_uqff.as_ref(),
585 UqffFullSer {
586 tokenizer: &tokenizer,
587 template_filename: paths.get_template_filename(),
588 generation_config: paths.get_gen_conf_filename(),
589 config: config.clone(),
590 processor_filename: paths.get_processor_config(),
591 preprocessor_filename: paths.get_preprocessor_config(),
592 modules: Some(&modules_ser),
593 module_paths: Some(&modules_config),
594 },
595 Arc::new(new_multi_progress()),
596 )?;
597 } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
598 model.load_from_artifacts(
599 device.clone(),
600 self.config.topology.as_ref(),
601 silent,
602 from_uqff,
603 )?;
604 }
605
606 let has_causal_attention = self.inner.has_causal_attention(&config)?;
607 let max_seq_len = self.inner.model_config(&config)?.max_seq_len();
608 Ok(Arc::new(Mutex::new(EmbeddingPipeline {
609 model,
610 tokenizer: tokenizer.into(),
611 model_id: self.model_id.clone(),
612 metadata: Arc::new(GeneralMetadata {
613 max_seq_len,
614 llg_factory: None,
615 is_xlora: false,
616 no_prefix_cache: false,
617 num_hidden_layers: 1, eos_tok: vec![],
619 kind: ModelKind::Normal,
620 no_kv_cache: true, activation_dtype: dtype,
622 sliding_window: None,
623 cache_config: None,
624 cache_engine: None,
625 model_metadata: None,
626 modalities: Modalities {
627 input: vec![SupportedModality::Text],
628 output: vec![SupportedModality::Embedding],
629 },
630 }),
631 topology: self.config.topology.clone(),
632 silent,
633 config,
634 modules_ser,
635 modules_manifest: modules_config,
636 mapper: pipeline_mapper,
637 modules,
638 processor: Arc::new(EmbeddingProcessor {
639 has_causal_attention,
640 }),
641 })))
642 }
643
644 fn get_id(&self) -> String {
645 self.model_id.to_string()
646 }
647
648 fn get_kind(&self) -> ModelKind {
649 self.kind.clone()
650 }
651}
652
653impl PreProcessingMixin for EmbeddingPipeline {
654 fn get_processor(&self) -> Arc<dyn Processor> {
655 self.processor.clone()
656 }
657 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
658 None
659 }
660 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
661 None
662 }
663}
664
665impl IsqPipelineMixin for EmbeddingPipeline {
666 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
667 let device = self.device().clone();
668 self.model
669 .quantize(
670 Some(dtype),
671 device,
672 self.topology.as_ref(),
673 self.silent,
674 None,
675 IsqOrganization::Default,
676 true,
677 None,
678 UqffFullSer {
679 tokenizer: &self.tokenizer,
680 template_filename: &None,
681 generation_config: None,
682 config: self.config.clone(),
683 processor_filename: &None,
684 preprocessor_filename: &None,
685 modules: Some(&self.modules_ser),
686 module_paths: Some(&self.modules_manifest),
687 },
688 Arc::new(new_multi_progress()),
689 )
690 .map_err(anyhow::Error::msg)
691 }
692}
693
694impl CacheManagerMixin for EmbeddingPipeline {
695 fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
696 fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
697 fn set_none_cache(
698 &self,
699 _seqs: &mut [&mut Sequence],
700 _reset_non_granular: bool,
701 _modify_draft_cache: bool,
702 _load_preallocated_cache: bool,
703 ) {
704 }
705 fn cache(&self) -> &EitherCache {
706 unreachable!()
707 }
708}
709
710impl MetadataMixin for EmbeddingPipeline {
711 fn device(&self) -> Device {
712 self.model.device().clone()
713 }
714 fn get_metadata(&self) -> Arc<GeneralMetadata> {
715 self.metadata.clone()
716 }
717 fn name(&self) -> String {
718 self.model_id.clone()
719 }
720 fn reset_non_granular_state(&self) {}
721 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
722 Some(self.tokenizer.clone())
723 }
724 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
725 Some(&*self.mapper)
726 }
727}
728
729#[async_trait::async_trait]
730impl Pipeline for EmbeddingPipeline {
731 fn forward_inputs(
732 &mut self,
733 inputs: Box<dyn Any>,
734 _return_raw_logits: bool,
735 ) -> candle_core::Result<ForwardInputsResult> {
736 let ModelInputs {
737 input_ids,
738 flash_meta,
739 } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
740
741 let mut xs = self.model.forward(&input_ids, &flash_meta)?;
742 for module in &self.modules {
743 xs = module.forward(&xs)?;
744 }
745
746 Ok(ForwardInputsResult::Embeddings { embeddings: xs })
747 }
748 async fn sample_causal_gen(
749 &self,
750 seqs: &mut [&mut Sequence],
751 logits: Vec<Tensor>,
752 prefix_cacher: &mut PrefixCacheManagerV2,
753 disable_eos_stop: bool,
754 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
755 ) -> Result<(), candle_core::Error> {
756 sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
757 }
758 fn category(&self) -> ModelCategory {
759 ModelCategory::Embedding
760 }
761}
762
763impl AnyMoePipelineMixin for EmbeddingPipeline {}