mistralrs_core/pipeline/vision.rs

use super::isq::ImatrixDataSource;
use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoVisionLoader,
    CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
    GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
    ModelKind, ModelPaths, MultimodalPromptPrefixer, Phi4MMLoader, PreProcessingMixin, Processor,
    Qwen2VLLoader, TokenSource, VLlama4Loader, VLlamaLoader, VisionModel, VisionModelLoader,
};
use super::{
    Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader, Phi3VLoader,
    Qwen2_5VLLoader, VisionLoaderType,
};
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::llg::build_llg_factory;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::ModelInputs;
use crate::{
    api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
    vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    // For full UQFF serialization
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}

/// A loader for a vision (non-quantized) model.
pub struct VisionLoader {
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Default)]
/// A builder for a loader for a vision (non-quantized) model.
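///
/// Example (a minimal sketch; the model ID is a placeholder):
/// ```ignore
/// let loader = VisionLoaderBuilder::new(
///     VisionSpecificConfig::default(),
///     None,                                 // chat template override
///     None,                                 // tokenizer.json override
///     Some("org/vision-model".to_string()), // hypothetical model ID
///     None,                                 // explicit JINJA template
/// )
/// .build(None); // `None` auto-detects the loader via `AutoVisionLoader`
/// ```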
pub struct VisionLoaderBuilder {
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config specific to loading a vision model.
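///
/// A sketch of a prequantized-UQFF setup (the path is a placeholder):
/// ```ignore
/// let cfg = VisionSpecificConfig {
///     from_uqff: Some(vec!["model-q4k-0.uqff".into()]),
///     max_edge: Some(1024), // cap on image edge length during preprocessing
///     ..Default::default()
/// };
/// ```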
pub struct VisionSpecificConfig {
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub max_edge: Option<u32>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
}

impl VisionLoaderBuilder {
    pub fn new(
        config: VisionSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn build(self, loader: Option<VisionLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn VisionModelLoader> = match loader {
            Some(VisionLoaderType::Phi3V) => Box::new(Phi3VLoader),
            Some(VisionLoaderType::Idefics2) => Box::new(Idefics2Loader),
            Some(VisionLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
            Some(VisionLoaderType::LLaVA) => Box::new(LLaVALoader),
            Some(VisionLoaderType::VLlama) => Box::new(VLlamaLoader),
            Some(VisionLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
            Some(VisionLoaderType::Idefics3) => Box::new(Idefics3Loader),
            Some(VisionLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
            Some(VisionLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
            Some(VisionLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
            Some(VisionLoaderType::Gemma3) => Box::new(Gemma3Loader),
            Some(VisionLoaderType::Mistral3) => Box::new(Mistral3Loader),
            Some(VisionLoaderType::Llama4) => Box::new(VLlama4Loader),
            None => Box::new(AutoVisionLoader),
        };
        Box::new(VisionLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            xlora_model_id: None,
            xlora_order: None,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for VisionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention(&config) {
            paged_attn_config = None;
        }

        let use_nccl = mistralrs_quant::distributed::use_nccl();

        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // If auto, convert to Map if not using nccl
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
            // We can promote to vision params if we get text params
            params = params.maybe_promote_to_vision();

            // Initial dtype
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // Disable ISQ if we are loading a prequantized model.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // ISQ or UQFF: quantized path
            // Match logic below where UQFF has priority
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };
330
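                    // Illustrative arithmetic (not from any specific model): if half
                    // the tensors pack ~4 elements per weight-equivalent (a Q4-style
                    // type vs. the activation dtype) and half are unquantized
                    // (factor 1), the averaged factor is ~2, so the size estimates
                    // below shrink roughly by half.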
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                params.max_seq_len(),
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        // TODO: PagedAttention is not supported with CPU for now.
        // This check is not strictly necessary because `get_device_layers` should prevent it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention; disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        // ISQ logic: if there is no calibration (i.e. no imatrix), allow immediate ISQ. Otherwise, fall back to the normal path.
        let mut loading_isq = if self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && self.config.write_uqff.is_none()
            && in_situ_quant.is_some()
        {
            let predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
            false
        } else {
            in_situ_quant.is_some()
        };

        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // Load onto the regular device if not using ISQ or if the calibration file is specified
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            // Special case where things can be loaded more optimally.
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                ),
                _ => unreachable!(),
            }
        };

        // Handle the Gemma 3 1b case here
        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        // Some repos (LLaVA, for example) don't place the processor config consistently.
        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        );

        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            // Tokenize; don't add the BOS token yet
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            // NOTE: We ONLY calibrate the text bits of these models!!
            // So only those should be tracked!
            model.begin_track_stats()?;

            const CHUNK_SIZE: usize = 1024;
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
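            // E.g. (illustrative numbers): a 2500-token calibration file yields
            // ceil(2500 / 1024) = 3 chunks, each prefixed with BOS below.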
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    None,
                )?;
                let _ = model.forward(
                    &inputs.input,
                    None, // NOTE: We ONLY calibrate the text bits of these models!!
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

        // Apply ISQ now; this branch is skipped when loading from UQFF, in which
        // case the serialized artifacts are loaded below instead.
        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                (Some(_), true) => unreachable!(),
            };
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                },
                Arc::new(MultiProgress::new()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                prompt_chunksize: self.config.prompt_chunksize,
                model_metadata: Some(model_metadata),
                modalities: self.inner.modalities(&config)?,
            }),
            processor,
            prefixer: self.inner.prefixer(&config),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}
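
// A minimal sketch of driving this loader (assuming `loader` was built via
// `VisionLoaderBuilder`; the dtype and mapper arguments are shown only
// schematically, and error handling is elided):
//
// let pipeline = loader.load_model_from_hf(
//     None,                      // revision (defaults to "main")
//     TokenSource::CacheToken,   // read the HF token from the cache
//     &dtype,                    // any `TryIntoDType` implementor
//     &Device::cuda_if_available(0)?,
//     false,                     // silent
//     mapper,                    // a `DeviceMapSetting`, e.g. `DeviceMapSetting::Auto(..)`
//     None,                      // no in-situ quantization
//     None,                      // no PagedAttention
// )?;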

impl PreProcessingMixin for VisionPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        Some(self.preprocessor_config.clone())
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        self.processor.clone()
    }
}

impl IsqPipelineMixin for VisionPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                },
                Arc::new(MultiProgress::new()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for VisionPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for VisionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected PagedAttention input metadata. It was not provided; please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Vision {
            prefixer: self.prefixer.clone(),
        }
    }
}

impl AnyMoePipelineMixin for VisionPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        // Precompile the regex here
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        // Index of the last char of the layer id, +1.
                        // Assumes keys of the form `<prefix>.N.<mlp>`.
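                        // E.g. (an illustrative key, not from any specific model):
                        // key = "model.layers.3.mlp.gate_proj.weight" with
                        // match_regex = "mlp" parses the span between the dots
                        // around "3" to layer 3.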
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}