// mistralrs_core/pipeline/vision.rs

use super::isq::ImatrixDataSource;
use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoVisionLoader,
    CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
    GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
    ModelKind, ModelPaths, MultimodalPromptPrefixer, Phi4MMLoader, PreProcessingMixin, Processor,
    Qwen2VLLoader, Qwen3VLLoader, TokenSource, VLlama4Loader, VLlamaLoader, VisionModel,
    VisionModelLoader,
};
use super::{
    Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader,
    Phi3VLoader, Qwen2_5VLLoader, VisionLoaderType,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::llg::build_llg_factory;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::ModelInputs;
use crate::{
    api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
    vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

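/// A vision pipeline: a loaded multimodal model plus everything needed to run
/// it (tokenizer, chat template, input processor, device mapper, and the
/// artifacts required for full UQFF serialization).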
pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    // For full UQFF serialization
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}

/// A loader for a vision (non-quantized) model.
pub struct VisionLoader {
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

/// A builder for a loader for a vision (non-quantized) model.
#[derive(Default)]
pub struct VisionLoaderBuilder {
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

/// Config specific to loading a vision model.
#[derive(Clone, Default)]
pub struct VisionSpecificConfig {
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub max_edge: Option<u32>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
    pub matformer_config_path: Option<PathBuf>,
    pub matformer_slice_name: Option<String>,
}
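
// Illustrative sketch (not part of the API surface): a config that loads
// prequantized UQFF artifacts. The path below is a placeholder. Note that
// `imatrix` and `calibration_file` are mutually exclusive; specifying both
// is rejected in `load_model_from_path`.
//
//     let cfg = VisionSpecificConfig {
//         from_uqff: Some(vec![PathBuf::from("path/to/model.uqff")]),
//         ..Default::default()
//     };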

impl VisionLoaderBuilder {
    pub fn new(
        config: VisionSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn build(self, loader: Option<VisionLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn VisionModelLoader> = match loader {
            Some(VisionLoaderType::Phi3V) => Box::new(Phi3VLoader),
            Some(VisionLoaderType::Idefics2) => Box::new(Idefics2Loader),
            Some(VisionLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
            Some(VisionLoaderType::LLaVA) => Box::new(LLaVALoader),
            Some(VisionLoaderType::VLlama) => Box::new(VLlamaLoader),
            Some(VisionLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
            Some(VisionLoaderType::Idefics3) => Box::new(Idefics3Loader),
            Some(VisionLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
            Some(VisionLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
            Some(VisionLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
            Some(VisionLoaderType::Gemma3) => Box::new(Gemma3Loader),
            Some(VisionLoaderType::Mistral3) => Box::new(Mistral3Loader),
            Some(VisionLoaderType::Llama4) => Box::new(VLlama4Loader),
            Some(VisionLoaderType::Gemma3n) => Box::new(Gemma3nLoader),
            Some(VisionLoaderType::Qwen3VL) => Box::new(Qwen3VLLoader),
            None => Box::new(AutoVisionLoader),
        };
        Box::new(VisionLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            xlora_model_id: None,
            xlora_order: None,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}
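
// Minimal usage sketch (hypothetical model id; `build(None)` defers model
// detection to `AutoVisionLoader`). Note that `build` unwraps `model_id`, so
// it must be `Some` here:
//
//     let loader = VisionLoaderBuilder::new(
//         VisionSpecificConfig::default(),
//         None,                                 // chat template override
//         None,                                 // tokenizer.json override
//         Some("org/vision-model".to_string()), // placeholder model id
//         None,                                 // explicit Jinja template
//     )
//     .build(None);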

impl Loader for VisionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

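    // High-level flow of `load_model_from_path`: read the model config, pick
    // devices (NCCL daemon, NCCL, or all devices similar to the requested one),
    // resolve the device map (estimating per-layer sizes under UQFF/ISQ/
    // prequantized pack factors), load the weights, build the processor and
    // tokenizer, optionally collect an imatrix from a calibration file, apply
    // ISQ or load UQFF artifacts, and finally set up PagedAttention caches.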
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention(&config) {
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // Load the Matformer slicing config if provided
        let matformer_slicing_config = if let Some(matformer_path) =
            &self.config.matformer_config_path
        {
            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
            info!("Loading Matformer config from {:?}", matformer_path);
            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);

            if let Some(slice_name) = &self.config.matformer_slice_name {
                info!("Using Matformer slice: {}", slice_name);
                Some(MatformerSliceConfig::new(slice_name.clone(), config))
            } else {
                // A config was provided without a slice name: return None and
                // let the model fall back to its default slice selection.
                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
                None
            }
        } else {
            None
        };

        // If auto, convert to Map if not using NCCL
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
            // We can promote to vision params if we get text params
            params = params.maybe_promote_to_vision();

            // Initial dtype
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // Disable ISQ if we are loading a prequantized model.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // ISQ or UQFF: quantized path
            // Match logic below where UQFF has priority
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // NOTE(EricLBuehler): the ISQ type is ALWAYS byte 4 (the 5th) of the tensor.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        // Use the average pack factor across all tensors as the
                        // effective compression factor for size estimation.
                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        // TODO: PagedAttention is not supported with CPU for now.
        // This check is not really necessary because `get_device_layers` should prevent it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu && paged_attn_config.is_some() {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        // ISQ logic: if there is no calibration (i.e., no imatrix), allow immediate ISQ. Otherwise, use the normal path.
        let mut loading_isq = if self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && self.config.write_uqff.is_none()
            && in_situ_quant.is_some()
        {
            let predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
            false
        } else {
            in_situ_quant.is_some()
        };

        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // Load onto the regular device if not using ISQ or if the calibration file is specified
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            // Special case for where things can be loaded more optimally.
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        };

        // Handle the Gemma 3 1b case here
        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        ); // There are always some repos that don't properly handle config position, for example... LLaVA

        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            // Tokenize, don't add BOS yet
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            // NOTE: We ONLY calibrate the text bits of these models!!
            // So only those should be tracked!
            model.begin_track_stats()?;

            const CHUNK_SIZE: usize = 1024;
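            // Each chunk below is prepended with the BOS token, so a forward
            // pass sees at most CHUNK_SIZE + 1 tokens; the KV cache is cleared
            // after every chunk so imatrix statistics are collected per chunk
            // rather than across the whole calibration file.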
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    None,
                )?;
                let _ = model.forward(
                    &inputs.input,
                    None, // NOTE: We ONLY calibrate the text bits of these models!!
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

        // Quantize on load (ISQ or topology-driven), unless we are loading prequantized UQFF artifacts
        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                // Both were specified: already rejected by the bail! above.
                (Some(_), true) => unreachable!(),
            };
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                },
                Arc::new(MultiProgress::new()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };
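        // Both are threaded into `GeneralMetadata` below; downstream scheduling
        // code uses them to manage PagedAttention KV-cache blocks.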

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                model_metadata: Some(model_metadata),
                modalities: self.inner.modalities(&config)?,
            }),
            processor,
            prefixer: self.inner.prefixer(&config),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for VisionPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        Some(self.preprocessor_config.clone())
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        self.processor.clone()
    }
}

impl IsqPipelineMixin for VisionPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                },
                Arc::new(MultiProgress::new()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for VisionPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for VisionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Vision {
            prefixer: self.prefixer.clone(),
        }
    }
}

impl AnyMoePipelineMixin for VisionPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        // Precompile regex here
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        // Extract the layer number from the key. Assumes keys of
                        // the form `...<layer_n>.<match_regex>...`, e.g. for
                        // `model.layers.3.mlp...` with match regex `mlp`,
                        // `last_layer_idx` is the `.` before `mlp` and
                        // `layer_n` parses to 3.
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}