use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, CacheManagerMixin,
    EitherCache, ForwardInputsResult, GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin,
    ModelCategory, ModelKind, ModelPaths, PreProcessingMixin, TokenSource,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::embedding_models::inputs_processor::{EmbeddingProcessor, ModelInputs};
use crate::embedding_models::{Dense, DenseActivation, Normalize, Pooling};
use crate::embedding_normal_model_loader;
use crate::embedding_normal_model_loader_sharded;
use crate::get_embedding_paths;
use crate::paged_attention::AttentionImplementation;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::EmbeddingLoaderType;
use crate::pipeline::EmbeddingModel;
use crate::pipeline::EmbeddingModelLoader;
use crate::pipeline::{AutoEmbeddingLoader, EmbeddingModulePaths};
use crate::pipeline::{ChatTemplate, EmbeddingModelPaths, IsqOrganization, Processor};
use crate::pipeline::{EmbeddingGemmaLoader, Qwen3EmbeddingLoader};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::{
    progress::{new_multi_progress, ProgressScopeGuard},
    tokens::get_token,
    varbuilder_utils::from_mmaped_safetensors,
};
use crate::Modalities;
use crate::SupportedModality;
use crate::{
    api_get_file, get_uqff_paths, DeviceMapSetting, PagedAttentionConfig, Pipeline, Topology,
    TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Context;
use anyhow::Result;
use candle_core::{Device, Tensor};
use candle_nn::{Linear, Module};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::safetensors::MmapedSafetensors;
use mistralrs_quant::{
    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
};
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::borrow::Cow;
use std::env;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

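/// Pipeline wrapping an embedding model together with its tokenizer, the
/// post-processing modules declared in `modules.json` (pooling, dense
/// projection, normalization), and general metadata.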
pub struct EmbeddingPipeline {
    model: Box<dyn EmbeddingModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    topology: Option<Topology>,
    silent: bool,
    config: String,
    modules_ser: String,
    modules_manifest: Vec<EmbeddingModulePaths>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    modules: Vec<Box<dyn Module + Send + Sync>>,
    processor: Arc<dyn Processor + Send + Sync>,
}

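/// [`Loader`] implementation for embedding models. Holds the architecture-specific
/// [`EmbeddingModelLoader`] plus the state needed to resolve model files
/// (token source, revision, optional UQFF artifact paths, HF cache path).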
pub struct EmbeddingLoader {
    inner: Box<dyn EmbeddingModelLoader>,
    model_id: String,
    config: EmbeddingSpecificConfig,
    kind: ModelKind,
    tokenizer_json: Option<String>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

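/// Builder for [`EmbeddingLoader`].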
#[derive(Default)]
pub struct EmbeddingLoaderBuilder {
    model_id: Option<String>,
    config: EmbeddingSpecificConfig,
    kind: ModelKind,
    tokenizer_json: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

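/// Configuration specific to embedding model loading: device/ISQ topology,
/// optional UQFF read/write paths, and an optional Hugging Face cache path.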
#[derive(Clone, Default)]
pub struct EmbeddingSpecificConfig {
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub hf_cache_path: Option<PathBuf>,
}

impl EmbeddingLoaderBuilder {
    pub fn new(
        config: EmbeddingSpecificConfig,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
    ) -> Self {
        Self {
            config,
            tokenizer_json,
            model_id,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn build(self, loader: Option<EmbeddingLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn EmbeddingModelLoader> = match loader {
            Some(EmbeddingLoaderType::EmbeddingGemma) => Box::new(EmbeddingGemmaLoader),
            Some(EmbeddingLoaderType::Qwen3Embedding) => Box::new(Qwen3EmbeddingLoader),
            None => Box::new(AutoEmbeddingLoader),
        };
        Box::new(EmbeddingLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            tokenizer_json: self.tokenizer_json,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for EmbeddingLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_embedding_paths!(
            EmbeddingModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for embedding models, disabling it.");
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

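        // Resolve the devices visible to this process: daemon workers and NCCL runs
        // are pinned to a single CUDA device, otherwise gather all devices similar
        // to the requested one.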
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

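        // Choose the device-mapping strategy. NCCL uses a dummy mapper on the first
        // device; automatic mapping estimates per-layer and non-mapped sizes
        // (scaled by the quantization pack factor) to place layers across devices.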
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

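        // Build two device mappers from the same setting: one retained by the
        // pipeline and one consumed while loading the weights.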
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

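        // Gather per-layer ISQ/device overrides from the topology and decide whether
        // quantization can be applied immediately at load time or must run as a
        // post-quantization pass.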
        let topology_overrides = self
            .config
            .topology
            .as_ref()
            .map(|topology| {
                topology
                    .pattern_overrides()
                    .into_iter()
                    .map(|(regex, layer)| ImmediateIsqOverride {
                        predicate: regex,
                        ty: layer.isq,
                        device: layer.device.clone(),
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        let has_override_isq = topology_overrides
            .iter()
            .any(|override_entry| override_entry.ty.is_some());
        let topology_requires_post_quant = self
            .config
            .topology
            .as_ref()
            .is_some_and(|topology| topology.requires_post_quantization());

        let allow_immediate_cli = !device.is_cuda() && in_situ_quant.is_some();

        let mut immediate_ty = None;
        let mut immediate_predicates = Vec::new();
        if allow_immediate_cli {
            immediate_ty = in_situ_quant;
            immediate_predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if immediate_predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
        }

        let use_immediate = allow_immediate_cli || has_override_isq;
        if use_immediate {
            mistralrs_quant::set_immediate_isq_with_overrides(
                immediate_ty,
                immediate_predicates.clone(),
                topology_overrides.clone(),
            );
        }

        let mut loading_isq = if use_immediate {
            false
        } else {
            in_situ_quant.is_some()
        };
        loading_isq |= topology_requires_post_quant;

        let load_device = if !loading_isq {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(new_multi_progress());

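        // Read the module stack declared in `modules.json`. The first entry must be
        // the transformer itself; pooling, dense-projection, and normalization
        // modules are instantiated here and applied after the model forward pass.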
        let modules_config: Vec<_> = paths
            .get_modules()
            .context("Embedding models require the `modules.json` file.")?
            .to_vec();
        assert!(matches!(
            modules_config.first(),
            Some(EmbeddingModulePaths::Transformer { .. })
        ));

        let mut modules: Vec<Box<dyn Module + Send + Sync>> = Vec::new();
        for module in &modules_config {
            match module {
                EmbeddingModulePaths::Transformer { .. } => (),
                EmbeddingModulePaths::Pooling { config, .. } => {
                    let layer: Pooling = serde_json::from_str(&std::fs::read_to_string(config)?)?;
                    modules.push(Box::new(layer));
                }
                EmbeddingModulePaths::Dense { config, model, .. } => {
                    let config: Dense = serde_json::from_str(&std::fs::read_to_string(config)?)?;
                    let safetensors = unsafe { MmapedSafetensors::new(model)? };
                    let weight = safetensors.load("linear.weight", &device, Some(dtype))?;
                    let bias = if config.bias {
                        Some(safetensors.load("linear.bias", &device, Some(dtype))?)
                    } else {
                        None
                    };
                    let (out_f, in_f) = weight.dims2()?;
                    assert_eq!((out_f, in_f), (config.out_features, config.in_features));
                    if !matches!(config.activation_function, DenseActivation::Identity) {
                        anyhow::bail!("Expected Identity activation function.");
                    }

                    modules.push(Box::new(Linear::new(weight, bias)));
                }
                EmbeddingModulePaths::Normalize { .. } => {
                    modules.push(Box::new(Normalize));
                }
            }
        }
        let modules_ser = EmbeddingModulePaths::serialize_modules(&modules_config);

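        // Load the transformer weights: sharded across ranks when NCCL is in use,
        // otherwise on the chosen load device with the per-layer device map.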
        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => embedding_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => embedding_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                ),
                _ => unreachable!(),
            }
        };

        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;

        let should_serialize = self.config.write_uqff.is_some();
        let should_quantize_pass = loading_isq;

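        // Run the post-quantization pass and/or UQFF serialization; when loading from
        // UQFF, restore the quantized tensors from the serialized artifacts instead.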
        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
            if should_quantize_pass {
                info!("Applying ISQ to all ranks.");
            } else {
                info!("Serializing existing ISQ tensors without additional quantization.");
            }
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                None,
                IsqOrganization::Default,
                should_quantize_pass,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                    modules: Some(&modules_ser),
                    module_paths: Some(&modules_config),
                },
                Arc::new(new_multi_progress()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

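        // Assemble the pipeline. Embedding models generate no tokens, so there is no
        // KV cache, no EOS tokens, and the output modality is an embedding vector.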
        let has_causal_attention = self.inner.has_causal_attention(&config)?;
        let max_seq_len = self.inner.model_config(&config)?.max_seq_len();
        Ok(Arc::new(Mutex::new(EmbeddingPipeline {
            model,
            tokenizer: tokenizer.into(),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1,
                eos_tok: vec![],
                kind: ModelKind::Normal,
                no_kv_cache: true,
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Embedding],
                },
            }),
            topology: self.config.topology.clone(),
            silent,
            config,
            modules_ser,
            modules_manifest: modules_config,
            mapper: pipeline_mapper,
            modules,
            processor: Arc::new(EmbeddingProcessor {
                has_causal_attention,
            }),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for EmbeddingPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        self.processor.clone()
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for EmbeddingPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                None,
                IsqOrganization::Default,
                true,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &None,
                    generation_config: None,
                    config: self.config.clone(),
                    processor_filename: &None,
                    preprocessor_filename: &None,
                    modules: Some(&self.modules_ser),
                    module_paths: Some(&self.modules_manifest),
                },
                Arc::new(new_multi_progress()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for EmbeddingPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        unreachable!()
    }
}

impl MetadataMixin for EmbeddingPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for EmbeddingPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        _return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");

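        // Run the base transformer, then apply the post-processing modules
        // (pooling, dense projection, normalization) in order.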
        let mut xs = self.model.forward(&input_ids, &flash_meta)?;
        for module in &self.modules {
            xs = module.forward(&xs)?;
        }

        Ok(ForwardInputsResult::Embeddings { embeddings: xs })
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Embedding
    }
}

impl AnyMoePipelineMixin for EmbeddingPipeline {}