mistralrs_core/pipeline/embedding.rs

use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, CacheManagerMixin,
    EitherCache, ForwardInputsResult, GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin,
    ModelCategory, ModelKind, ModelPaths, PreProcessingMixin, TokenSource,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::embedding_models::inputs_processor::{EmbeddingProcessor, ModelInputs};
use crate::embedding_models::{Dense, DenseActivation, Normalize, Pooling};
use crate::embedding_normal_model_loader;
use crate::embedding_normal_model_loader_sharded;
use crate::get_embedding_paths;
use crate::paged_attention::AttentionImplementation;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::EmbeddingLoaderType;
use crate::pipeline::EmbeddingModel;
use crate::pipeline::EmbeddingModelLoader;
use crate::pipeline::{AutoEmbeddingLoader, EmbeddingModulePaths};
use crate::pipeline::{ChatTemplate, EmbeddingModelPaths, IsqOrganization, Processor};
use crate::pipeline::{EmbeddingGemmaLoader, Qwen3EmbeddingLoader};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::Modalities;
use crate::SupportedModality;
use crate::{
    api_get_file, get_uqff_paths, DeviceMapSetting, PagedAttentionConfig, Pipeline, Topology,
    TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Context;
use anyhow::Result;
use candle_core::{Device, Tensor};
use candle_nn::{Linear, Module};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::safetensors::MmapedSafetensors;
use mistralrs_quant::{
    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
};
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::borrow::Cow;
use std::env;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

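/// Runtime pipeline for embedding models: wraps the loaded transformer together with the
/// post-processing modules (pooling, optional dense projection, normalization) described by
/// the repository's `modules.json`.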
pub struct EmbeddingPipeline {
    model: Box<dyn EmbeddingModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    topology: Option<Topology>,
    silent: bool,
    config: String,
    modules_ser: String,
    modules_manifest: Vec<EmbeddingModulePaths>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    modules: Vec<Box<dyn Module + Send + Sync>>,
    processor: Arc<dyn Processor + Send + Sync>,
}

/// A loader for an embedding (non-quantized) model.
pub struct EmbeddingLoader {
    inner: Box<dyn EmbeddingModelLoader>,
    model_id: String,
    config: EmbeddingSpecificConfig,
    kind: ModelKind,
    tokenizer_json: Option<String>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Default)]
/// A builder for a loader for an embedding (non-quantized) model.
pub struct EmbeddingLoaderBuilder {
    model_id: Option<String>,
    config: EmbeddingSpecificConfig,
    kind: ModelKind,
    tokenizer_json: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config specific to loading an embedding model.
pub struct EmbeddingSpecificConfig {
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub hf_cache_path: Option<PathBuf>,
}

impl EmbeddingLoaderBuilder {
    pub fn new(
        config: EmbeddingSpecificConfig,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
    ) -> Self {
        Self {
            config,
            tokenizer_json,
            model_id,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn build(self, loader: Option<EmbeddingLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn EmbeddingModelLoader> = match loader {
            Some(EmbeddingLoaderType::EmbeddingGemma) => Box::new(EmbeddingGemmaLoader),
            Some(EmbeddingLoaderType::Qwen3Embedding) => Box::new(Qwen3EmbeddingLoader),
            None => Box::new(AutoEmbeddingLoader),
        };
        Box::new(EmbeddingLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            tokenizer_json: self.tokenizer_json,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

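// Usage sketch (hypothetical model id and call site, shown only to illustrate the builder
// flow; the real entry points live elsewhere in the crate):
//
// let loader = EmbeddingLoaderBuilder::new(
//     EmbeddingSpecificConfig::default(),
//     None, // no tokenizer.json override
//     Some("org/some-embedding-model".to_string()),
// )
// .build(None); // `None` selects `AutoEmbeddingLoader`
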
impl Loader for EmbeddingLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

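        // Resolve the model artifact paths (config, tokenizer, weights, modules.json, and any
        // UQFF files) from the Hub, then delegate to `load_model_from_path`.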
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_embedding_paths!(
            EmbeddingModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for embedding models, disabling it.");
        }
        let _ = paged_attn_config;

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

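        // Pick the devices to load onto: a single CUDA stream per rank when running as a
        // distributed worker/daemon or under NCCL, otherwise every device similar to the one
        // that was requested.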
        let use_nccl = mistralrs_quant::distributed::use_nccl();

        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // If auto, convert to Map if not using nccl
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            // Initial dtype
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // Disable ISQ if we are loading a prequantized model.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // ISQ or UQFF: quantized path
            // Match logic below where UQFF has priority
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

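        // Two identical mappers are built from the final mapping: one is stored on the
        // pipeline, the other drives per-layer device and dtype selection below.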
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        let topology_overrides = self
            .config
            .topology
            .as_ref()
            .map(|topology| {
                topology
                    .pattern_overrides()
                    .into_iter()
                    .map(|(regex, layer)| ImmediateIsqOverride {
                        predicate: regex,
                        ty: layer.isq,
                        device: layer.device.clone(),
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        let has_override_isq = topology_overrides
            .iter()
            .any(|override_entry| override_entry.ty.is_some());
        let topology_requires_post_quant = self
            .config
            .topology
            .as_ref()
            .is_some_and(|topology| topology.requires_post_quantization());

        let allow_immediate_cli = !device.is_cuda() && in_situ_quant.is_some();

        let mut immediate_ty = None;
        let mut immediate_predicates = Vec::new();
        if allow_immediate_cli {
            immediate_ty = in_situ_quant;
            immediate_predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if immediate_predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
        }

        let use_immediate = allow_immediate_cli || has_override_isq;
        if use_immediate {
            mistralrs_quant::set_immediate_isq_with_overrides(
                immediate_ty,
                immediate_predicates.clone(),
                topology_overrides.clone(),
            );
        }

        // Logic for ISQ here: if no calibration (i.e., imatrix), then allow immediate ISQ. Otherwise, back to normal.
        let mut loading_isq = if use_immediate {
            false
        } else {
            in_situ_quant.is_some()
        };
        loading_isq |= topology_requires_post_quant;

        // Load onto the regular device if not using isq
        let load_device = if !loading_isq {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

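        // When post-load ISQ is requested, weights are first materialized on the CPU
        // (see `load_device` above); quantization afterwards targets `device`.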
        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

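        // Build the post-processing modules declared in `modules.json`. The first entry must be
        // the transformer itself; the remaining entries (pooling, optional dense projection as a
        // `Linear`, and normalization) are applied in order to the transformer's output.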
        let modules_config: Vec<_> = paths
            .get_modules()
            .context("Embedding models require the `modules.json` file.")?
            .to_vec();
        assert!(matches!(
            modules_config.first(),
            Some(EmbeddingModulePaths::Transformer { .. })
        ));

        let mut modules: Vec<Box<dyn Module + Send + Sync>> = Vec::new();
        for module in &modules_config {
            match module {
                EmbeddingModulePaths::Transformer { .. } => (),
                EmbeddingModulePaths::Pooling { config, .. } => {
                    let layer: Pooling = serde_json::from_str(&std::fs::read_to_string(config)?)?;
                    modules.push(Box::new(layer));
                }
                EmbeddingModulePaths::Dense { config, model, .. } => {
                    let config: Dense = serde_json::from_str(&std::fs::read_to_string(config)?)?;
                    let safetensors = unsafe { MmapedSafetensors::new(model)? };
                    let weight = safetensors.load("linear.weight", &device, Some(dtype))?;
                    let bias = if config.bias {
                        Some(safetensors.load("linear.bias", &device, Some(dtype))?)
                    } else {
                        None
                    };
                    let (out_f, in_f) = weight.dims2()?;
                    assert_eq!((out_f, in_f), (config.out_features, config.in_features));
                    if !matches!(config.activation_function, DenseActivation::Identity) {
                        anyhow::bail!("Expected Identity activation function.");
                    }

                    modules.push(Box::new(Linear::new(weight, bias)));
                }
                EmbeddingModulePaths::Normalize { .. } => {
                    modules.push(Box::new(Normalize));
                }
            }
        }
        let modules_ser = EmbeddingModulePaths::serialize_modules(&modules_config);

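        // Load the transformer weights themselves: under NCCL the varbuilder is sharded across
        // ranks, otherwise the plain loader is used with the (possibly CPU) load device chosen above.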
        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            // Special case for where things can be more optimally loaded.
            match self.kind {
                ModelKind::Normal => embedding_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => embedding_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                ),
                _ => unreachable!(),
            }
        };

        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;

        let should_serialize = self.config.write_uqff.is_some();
        let should_quantize_pass = loading_isq;

        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
            if should_quantize_pass {
                info!("Applying ISQ to all ranks.");
            } else {
                info!("Serializing existing ISQ tensors without additional quantization.");
            }
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                None,
                IsqOrganization::Default,
                should_quantize_pass,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                    modules: Some(&modules_ser),
                    module_paths: Some(&modules_config),
                },
                Arc::new(MultiProgress::new()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

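        // Embedding pipelines generate no tokens and keep no KV cache, so most of the
        // generation-oriented metadata below is inert (empty EOS set, no cache engine, etc.).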
        let has_causal_attention = self.inner.has_causal_attention(&config)?;
        let max_seq_len = self.inner.model_config(&config)?.max_seq_len();
        Ok(Arc::new(Mutex::new(EmbeddingPipeline {
            model,
            tokenizer: tokenizer.into(),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1, // FIXME(EricLBuehler): we know this is only for caching, so it's OK.
                eos_tok: vec![],
                kind: ModelKind::Normal,
                no_kv_cache: true, // NOTE(EricLBuehler): no cache for these.
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Embedding],
                },
            }),
            topology: self.config.topology.clone(),
            silent,
            config,
            modules_ser,
            modules_manifest: modules_config,
            mapper: pipeline_mapper,
            modules,
            processor: Arc::new(EmbeddingProcessor {
                has_causal_attention,
            }),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for EmbeddingPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        self.processor.clone()
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for EmbeddingPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                None,
                IsqOrganization::Default,
                true,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &None,
                    generation_config: None,
                    config: self.config.clone(),
                    processor_filename: &None,
                    preprocessor_filename: &None,
                    modules: Some(&self.modules_ser),
                    module_paths: Some(&self.modules_manifest),
                },
                Arc::new(MultiProgress::new()),
            )
            .map_err(anyhow::Error::msg)
    }
}

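// Embedding models keep no KV cache, so the cache-management hooks are intentionally no-ops
// and `cache()` is never expected to be called.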
impl CacheManagerMixin for EmbeddingPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        unreachable!()
    }
}

impl MetadataMixin for EmbeddingPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for EmbeddingPipeline {
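    // The forward pass runs the transformer, then chains each post-processing module
    // (pooling, dense projection, normalization) over its output to produce the embeddings.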
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        _return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");

        let mut xs = self.model.forward(&input_ids, &flash_meta)?;
        for module in &self.modules {
            xs = module.forward(&xs)?;
        }

        Ok(ForwardInputsResult::Embeddings { embeddings: xs })
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Embedding
    }
}

impl AnyMoePipelineMixin for EmbeddingPipeline {}