// mistralrs_core/pipeline/vision.rs

use super::isq::ImatrixDataSource;
use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoVisionLoader,
    CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
    GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
    ModelKind, ModelPaths, MultimodalPromptPrefixer, Phi4MMLoader, PreProcessingMixin, Processor,
    Qwen2VLLoader, Qwen3VLLoader, TokenSource, VLlama4Loader, VLlamaLoader, VisionModel,
    VisionModelLoader,
};
use super::{
    Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader,
    Phi3VLoader, Qwen2_5VLLoader, VisionLoaderType,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::llg::build_llg_factory;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::ModelInputs;
use crate::{
    api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
    vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::{
    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

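/// A loaded vision pipeline: the model together with its tokenizer, chat
/// template, multimodal processor, device mapper, and the artifact paths
/// retained for full UQFF serialization.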
pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    // For full UQFF serialization
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}

/// A loader for a vision (non-quantized) model.
pub struct VisionLoader {
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

/// A builder for a loader for a vision (non-quantized) model.
#[derive(Default)]
pub struct VisionLoaderBuilder {
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

/// Config specific to loading a vision model.
#[derive(Clone, Default)]
pub struct VisionSpecificConfig {
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub max_edge: Option<u32>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
    pub matformer_config_path: Option<PathBuf>,
    pub matformer_slice_name: Option<String>,
}

impl VisionLoaderBuilder {
    pub fn new(
        config: VisionSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

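    /// Build a [`Loader`] for this configuration. Passing `None` selects
    /// [`AutoVisionLoader`], which picks the concrete loader automatically.
    /// Note that `build` expects a model id to have been provided (it is
    /// unwrapped here). A minimal usage sketch; the model id below is a
    /// placeholder, not a recommendation:
    ///
    /// ```ignore
    /// let loader = VisionLoaderBuilder::new(
    ///     VisionSpecificConfig::default(),
    ///     None,                                 // chat template override
    ///     None,                                 // tokenizer.json override
    ///     Some("org/vision-model".to_string()), // hypothetical model id
    ///     None,                                 // explicit Jinja template
    /// )
    /// .build(None); // `None` => auto-detect the vision architecture
    /// ```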
    pub fn build(self, loader: Option<VisionLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn VisionModelLoader> = match loader {
            Some(VisionLoaderType::Phi3V) => Box::new(Phi3VLoader),
            Some(VisionLoaderType::Idefics2) => Box::new(Idefics2Loader),
            Some(VisionLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
            Some(VisionLoaderType::LLaVA) => Box::new(LLaVALoader),
            Some(VisionLoaderType::VLlama) => Box::new(VLlamaLoader),
            Some(VisionLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
            Some(VisionLoaderType::Idefics3) => Box::new(Idefics3Loader),
            Some(VisionLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
            Some(VisionLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
            Some(VisionLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
            Some(VisionLoaderType::Gemma3) => Box::new(Gemma3Loader),
            Some(VisionLoaderType::Mistral3) => Box::new(Mistral3Loader),
            Some(VisionLoaderType::Llama4) => Box::new(VLlama4Loader),
            Some(VisionLoaderType::Gemma3n) => Box::new(Gemma3nLoader),
            Some(VisionLoaderType::Qwen3VL) => Box::new(Qwen3VLLoader),
            None => Box::new(AutoVisionLoader),
        };
        Box::new(VisionLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            xlora_model_id: None,
            xlora_order: None,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for VisionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

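        // Resolve every artifact path (config, tokenizer, weights, and any UQFF
        // files) from the Hub or the local cache before loading.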
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention(&config) {
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

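        // Pick the candidate devices: daemon workers and NCCL runs each pin to a
        // single CUDA device/stream, while the normal path gathers all devices
        // similar to the requested one for automatic mapping.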
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // Load matformer slicing config if provided
        let matformer_slicing_config = if let Some(matformer_path) =
            &self.config.matformer_config_path
        {
            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
            info!("Loading Matformer config from {:?}", matformer_path);
            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);

            if let Some(slice_name) = &self.config.matformer_slice_name {
                info!("Using Matformer slice: {}", slice_name);
                Some(MatformerSliceConfig::new(slice_name.clone(), config))
            } else {
                // If no slice name is provided but config exists, we'll need to handle this
                // For now, return None and let the model handle the default slice selection
                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
                None
            }
        } else {
            None
        };

        // If auto, convert to Map if not using nccl
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
            // We can promote to vision params if we get text params
            params = params.maybe_promote_to_vision();

            // Initial dtype
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // Disable ISQ if we are loading a prequantized model.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // ISQ or UQFF: quantized path
            // Match logic below where UQFF has priority
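            // The pack factor is how many weight elements fit in one storage element
            // under a given quantization; it scales the per-layer size estimates fed
            // to the automatic device mapper below.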
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

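        // Materialize two mappers from the now-concrete mapping: one is kept on the
        // pipeline for introspection, the other is consumed while loading weights.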
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        // TODO: PagedAttention is not supported with CPU for now.
        // This check is not really necessary because `get_device_layers` should prevent it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu && paged_attn_config.is_some() {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        let topology_overrides = self
            .config
            .topology
            .as_ref()
            .map(|topology| {
                topology
                    .pattern_overrides()
                    .into_iter()
                    .map(|(regex, layer)| ImmediateIsqOverride {
                        predicate: regex,
                        ty: layer.isq,
                        device: layer.device.clone(),
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        let has_override_isq = topology_overrides
            .iter()
            .any(|override_entry| override_entry.ty.is_some());
        let topology_requires_post_quant = self
            .config
            .topology
            .as_ref()
            .is_some_and(|topology| topology.requires_post_quantization());

        let allow_immediate_cli = self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && in_situ_quant.is_some();

        let mut immediate_ty = None;
        let mut immediate_predicates = Vec::new();
        if allow_immediate_cli {
            immediate_ty = in_situ_quant;
            immediate_predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if immediate_predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
        }

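        // Immediate ISQ quantizes tensors as they are loaded; the deferred path
        // below instead loads everything first and quantizes in a separate pass
        // (as required when collecting an imatrix).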
        let use_immediate = allow_immediate_cli || has_override_isq;
        if use_immediate {
            mistralrs_quant::set_immediate_isq_with_overrides(
                immediate_ty,
                immediate_predicates.clone(),
                topology_overrides.clone(),
            );
        }

        // ISQ logic: if there is no calibration data (i.e., no imatrix), allow immediate ISQ.
        // Otherwise, fall back to the normal post-load quantization pass.
        let mut loading_isq = if use_immediate {
            false
        } else {
            in_situ_quant.is_some()
        };
        if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
            loading_isq = true;
        }
        loading_isq |= topology_requires_post_quant;

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // Load onto the regular device if not using ISQ, or if a calibration file is specified
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

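        // Load the model weights. With NCCL the weights are sharded across ranks;
        // otherwise the normal (possibly device-mapped) loading path is used.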
        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            // Special case for where things can be more optimally loaded.
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        };

        // Handle the Gemma 3 1b case here
        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        // There are always some repos that don't properly handle the config position (LLaVA, for example).
        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        );

        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

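        // If a calibration file was given, run its tokens through the model once
        // so the tracked statistics can serve as a collected imatrix for the ISQ
        // pass below.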
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            // Tokenize, don't add bos yet
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            // NOTE: We ONLY calibrate the text bits of these models!!
            // So only those should be tracked!
            model.begin_track_stats()?;

            const CHUNK_SIZE: usize = 1024;
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    None,
                )?;
                let _ = model.forward(
                    &inputs.input,
                    None, // NOTE: We ONLY calibrate the text bits of these models!!
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

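        // Decide whether to run a quantization pass and/or serialize UQFF artifacts;
        // a model restored from UQFF skips both and loads its quantized tensors
        // directly via `load_from_artifacts`.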
        let should_serialize = self.config.write_uqff.is_some();
        let should_quantize_pass = loading_isq;

        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
            let imatrix_source = if should_quantize_pass {
                match (
                    self.config.imatrix.as_ref(),
                    self.config.calibration_file.is_some(),
                ) {
                    (None, false) => None,
                    (Some(file), false) => Some(ImatrixDataSource::File(file)),
                    (None, true) => Some(ImatrixDataSource::Collected),
                    (Some(_), true) => unreachable!(),
                }
            } else {
                None
            };
            if should_quantize_pass {
                info!("Applying ISQ to all ranks.");
            } else {
                info!("Serializing existing ISQ tensors without additional quantization.");
            }
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                should_quantize_pass,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                    modules: None,
                    module_paths: None,
                },
                Arc::new(MultiProgress::new()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                model_metadata: Some(model_metadata),
                modalities: self.inner.modalities(&config)?,
            }),
            processor,
            prefixer: self.inner.prefixer(&config),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for VisionPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        Some(self.preprocessor_config.clone())
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        self.processor.clone()
    }
}

impl IsqPipelineMixin for VisionPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                true,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                    modules: None,
                    module_paths: None,
                },
                Arc::new(MultiProgress::new()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for VisionPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for VisionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Vision {
            prefixer: self.prefixer.clone(),
        }
    }
}

impl AnyMoePipelineMixin for VisionPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        // Precompile regex here
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
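            // Only tensors whose key matches the regex and whose parsed layer index
            // is listed in `layers` (or any layer, when `layers` is empty) are loaded
            // as expert weights.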
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        // Index one past the last char of the layer id;
                        // assumes keys of the form `...<layer>.<match>...`.
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}