mistralrs_core/pipeline/embedding.rs

use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, CacheManagerMixin,
    EitherCache, ForwardInputsResult, GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin,
    ModelCategory, ModelKind, ModelPaths, PreProcessingMixin, TokenSource,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::embedding_models::inputs_processor::{EmbeddingProcessor, ModelInputs};
use crate::embedding_models::{Dense, DenseActivation, Normalize, Pooling};
use crate::embedding_normal_model_loader;
use crate::embedding_normal_model_loader_sharded;
use crate::get_embedding_paths;
use crate::paged_attention::AttentionImplementation;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::EmbeddingLoaderType;
use crate::pipeline::EmbeddingModel;
use crate::pipeline::EmbeddingModelLoader;
use crate::pipeline::{AutoEmbeddingLoader, EmbeddingModulePaths};
use crate::pipeline::{ChatTemplate, EmbeddingModelPaths, IsqOrganization, Processor};
use crate::pipeline::{EmbeddingGemmaLoader, Qwen3EmbeddingLoader};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::{
    progress::{new_multi_progress, ProgressScopeGuard},
    tokens::get_token,
    varbuilder_utils::from_mmaped_safetensors,
};
use crate::Modalities;
use crate::SupportedModality;
use crate::{
    api_get_file, get_uqff_paths, DeviceMapSetting, PagedAttentionConfig, Pipeline, Topology,
    TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Context;
use anyhow::Result;
use candle_core::{Device, Tensor};
use candle_nn::{Linear, Module};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::safetensors::MmapedSafetensors;
use mistralrs_quant::{
    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
};
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::borrow::Cow;
use std::env;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

pub struct EmbeddingPipeline {
    model: Box<dyn EmbeddingModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    topology: Option<Topology>,
    silent: bool,
    config: String,
    modules_ser: String,
    modules_manifest: Vec<EmbeddingModulePaths>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    modules: Vec<Box<dyn Module + Send + Sync>>,
    processor: Arc<dyn Processor + Send + Sync>,
}

/// A loader for an embedding (non-quantized) model.
pub struct EmbeddingLoader {
    inner: Box<dyn EmbeddingModelLoader>,
    model_id: String,
    config: EmbeddingSpecificConfig,
    kind: ModelKind,
    tokenizer_json: Option<String>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Default)]
/// A builder for a loader for an embedding (non-quantized) model.
pub struct EmbeddingLoaderBuilder {
    model_id: Option<String>,
    config: EmbeddingSpecificConfig,
    kind: ModelKind,
    tokenizer_json: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config specific to loading an embedding model.
pub struct EmbeddingSpecificConfig {
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub hf_cache_path: Option<PathBuf>,
}

impl EmbeddingLoaderBuilder {
    pub fn new(
        config: EmbeddingSpecificConfig,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
    ) -> Self {
        Self {
            config,
            tokenizer_json,
            model_id,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn build(self, loader: Option<EmbeddingLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn EmbeddingModelLoader> = match loader {
            Some(EmbeddingLoaderType::EmbeddingGemma) => Box::new(EmbeddingGemmaLoader),
            Some(EmbeddingLoaderType::Qwen3Embedding) => Box::new(Qwen3EmbeddingLoader),
            None => Box::new(AutoEmbeddingLoader),
        };
        Box::new(EmbeddingLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            tokenizer_json: self.tokenizer_json,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for EmbeddingLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_embedding_paths!(
            EmbeddingModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for embedding models, disabling it.");
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

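        // Select candidate devices: daemon workers take the CUDA device at `worker_rank + 1`,
        // the NCCL primary takes CUDA device 0, and otherwise all devices similar to the
        // requested one are considered for automatic mapping.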
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // If auto, convert to Map if not using nccl
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            // Initial dtype
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // Disable ISQ if we are loading a prequantized model.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // ISQ or UQFF: quantized path
            // Match logic below where UQFF has priority
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
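                        // Average the per-tensor pack factor across all serialized UQFF tensors;
                        // this drives the memory estimate used for automatic device mapping below.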
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

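        // Two mappers are derived from the same setting: `pipeline_mapper` is kept on the
        // pipeline for `device_mapper()`, while `mapper` is used here for per-layer device
        // placement and dtype selection during loading.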
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

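        // Translate per-layer topology pattern overrides into immediate ISQ overrides so that
        // matching weights can be quantized (and placed) as they are loaded.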
        let topology_overrides = self
            .config
            .topology
            .as_ref()
            .map(|topology| {
                topology
                    .pattern_overrides()
                    .into_iter()
                    .map(|(regex, layer)| ImmediateIsqOverride {
                        predicate: regex,
                        ty: layer.isq,
                        device: layer.device.clone(),
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        let has_override_isq = topology_overrides
            .iter()
            .any(|override_entry| override_entry.ty.is_some());
        let topology_requires_post_quant = self
            .config
            .topology
            .as_ref()
            .is_some_and(|topology| topology.requires_post_quantization());

        let allow_immediate_cli = !device.is_cuda() && in_situ_quant.is_some();

        let mut immediate_ty = None;
        let mut immediate_predicates = Vec::new();
        if allow_immediate_cli {
            immediate_ty = in_situ_quant;
            immediate_predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if immediate_predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
        }

        let use_immediate = allow_immediate_cli || has_override_isq;
        if use_immediate {
            mistralrs_quant::set_immediate_isq_with_overrides(
                immediate_ty,
                immediate_predicates.clone(),
                topology_overrides.clone(),
            );
        }

        // Logic for ISQ here: if there is no calibration (i.e. imatrix), then allow immediate ISQ. Otherwise, back to normal.
        let mut loading_isq = if use_immediate {
            false
        } else {
            in_situ_quant.is_some()
        };
        loading_isq |= topology_requires_post_quant;

        // Load onto the regular device if not using isq
        let load_device = if !loading_isq {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(new_multi_progress());

        let modules_config: Vec<_> = paths
            .get_modules()
            .context("Embedding models require the `modules.json` file.")?
            .to_vec();
        assert!(matches!(
            modules_config.first(),
            Some(EmbeddingModulePaths::Transformer { .. })
        ));

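        // Build the post-transformer module chain described by `modules.json`: pooling, optional
        // dense projections (loaded as plain `Linear` layers), and normalization, applied in
        // order after the transformer forward pass.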
        let mut modules: Vec<Box<dyn Module + Send + Sync>> = Vec::new();
        for module in &modules_config {
            match module {
                EmbeddingModulePaths::Transformer { .. } => (),
                EmbeddingModulePaths::Pooling { config, .. } => {
                    let layer: Pooling = serde_json::from_str(&std::fs::read_to_string(config)?)?;
                    modules.push(Box::new(layer));
                }
                EmbeddingModulePaths::Dense { config, model, .. } => {
                    let config: Dense = serde_json::from_str(&std::fs::read_to_string(config)?)?;
                    let safetensors = unsafe { MmapedSafetensors::new(model)? };
                    let weight = safetensors.load("linear.weight", &device, Some(dtype))?;
                    let bias = if config.bias {
                        Some(safetensors.load("linear.bias", &device, Some(dtype))?)
                    } else {
                        None
                    };
                    let (out_f, in_f) = weight.dims2()?;
                    assert_eq!((out_f, in_f), (config.out_features, config.in_features));
                    if !matches!(config.activation_function, DenseActivation::Identity) {
                        anyhow::bail!("Expected Identity activation function.");
                    }

                    modules.push(Box::new(Linear::new(weight, bias)));
                }
                EmbeddingModulePaths::Normalize { .. } => {
                    modules.push(Box::new(Normalize));
                }
            }
        }
        let modules_ser = EmbeddingModulePaths::serialize_modules(&modules_config);

        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            // Special case for where things can be more optimally loaded.
            match self.kind {
                ModelKind::Normal => embedding_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => embedding_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                ),
                _ => unreachable!(),
            }
        };

        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;

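        // ISQ / UQFF handling: quantize in place (and optionally serialize to UQFF) unless we are
        // loading from an existing UQFF artifact, in which case the serialized tensors are loaded
        // directly.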
        let should_serialize = self.config.write_uqff.is_some();
        let should_quantize_pass = loading_isq;

        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
            if should_quantize_pass {
                info!("Applying ISQ to all ranks.");
            } else {
                info!("Serializing existing ISQ tensors without additional quantization.");
            }
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                None,
                IsqOrganization::Default,
                should_quantize_pass,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                    modules: Some(&modules_ser),
                    module_paths: Some(&modules_config),
                },
                Arc::new(new_multi_progress()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

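        // Embedding pipelines keep no KV cache; the metadata advertises text input and embedding
        // output modalities.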
        let has_causal_attention = self.inner.has_causal_attention(&config)?;
        let max_seq_len = self.inner.model_config(&config)?.max_seq_len();
        Ok(Arc::new(Mutex::new(EmbeddingPipeline {
            model,
            tokenizer: tokenizer.into(),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1, // FIXME(EricLBuehler): we know this is only for caching, so it's OK.
                eos_tok: vec![],
                kind: ModelKind::Normal,
                no_kv_cache: true, // NOTE(EricLBuehler): no cache for these.
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Embedding],
                },
            }),
            topology: self.config.topology.clone(),
            silent,
            config,
            modules_ser,
            modules_manifest: modules_config,
            mapper: pipeline_mapper,
            modules,
            processor: Arc::new(EmbeddingProcessor {
                has_causal_attention,
            }),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for EmbeddingPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        self.processor.clone()
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for EmbeddingPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                None,
                IsqOrganization::Default,
                true,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &None,
                    generation_config: None,
                    config: self.config.clone(),
                    processor_filename: &None,
                    preprocessor_filename: &None,
                    modules: Some(&self.modules_ser),
                    module_paths: Some(&self.modules_manifest),
                },
                Arc::new(new_multi_progress()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for EmbeddingPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        unreachable!()
    }
}

impl MetadataMixin for EmbeddingPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for EmbeddingPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        _return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");

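        // Run the transformer, then apply the post-processing modules (pooling, dense,
        // normalize) in order.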
        let mut xs = self.model.forward(&input_ids, &flash_meta)?;
        for module in &self.modules {
            xs = module.forward(&xs)?;
        }

        Ok(ForwardInputsResult::Embeddings { embeddings: xs })
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Embedding
    }
}

impl AnyMoePipelineMixin for EmbeddingPipeline {}