mistralrs_core/pipeline/vision.rs

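//! Vision (multimodal) model pipeline: loading, device mapping, in-situ
//! quantization (ISQ), UQFF (de)serialization, PagedAttention setup, and the
//! forward pass for image+text models.
//!
//! A minimal loading sketch (the model ID is a placeholder; error handling
//! elided):
//!
//! ```ignore
//! let loader = VisionLoaderBuilder::new(
//!     VisionSpecificConfig::default(),
//!     None, // chat template
//!     None, // tokenizer.json override
//!     Some("org/model-id".to_string()),
//!     None, // explicit Jinja template
//! )
//! .build(None); // `None` selects `AutoVisionLoader`
//! ```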
use super::isq::ImatrixDataSource;
use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoVisionLoader,
    CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
    GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
    ModelKind, ModelPaths, MultimodalPromptPrefixer, Phi4MMLoader, PreProcessingMixin, Processor,
    Qwen2VLLoader, Qwen3VLLoader, TokenSource, VLlama4Loader, VLlamaLoader, VisionModel,
    VisionModelLoader,
};
use super::{
    Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader,
    Phi3VLoader, Qwen2_5VLLoader, VisionLoaderType,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::llg::build_llg_factory;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{
    progress::{new_multi_progress, ProgressScopeGuard},
    tokens::get_token,
    varbuilder_utils::from_mmaped_safetensors,
};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::ModelInputs;
use crate::{
    api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
    vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::{
    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

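/// A fully loaded vision pipeline: the model itself plus its tokenizer, chat
/// template, multimodal (pre)processors, device mapper, and the metadata the
/// scheduler and samplers need.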
pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    // For full UQFF serialization
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}

/// A loader for a vision (non-quantized) model.
pub struct VisionLoader {
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Default)]
/// A builder for a loader for a vision (non-quantized) model.
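///
/// Illustrative sketch (model and adapter IDs are placeholders):
///
/// ```ignore
/// let loader = VisionLoaderBuilder::new(
///     VisionSpecificConfig::default(),
///     None,
///     None,
///     Some("org/vision-model".to_string()),
///     None,
/// )
/// .with_lora(vec!["org/lora-adapter".to_string()])
/// .build(Some(VisionLoaderType::Qwen2VL));
/// ```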
pub struct VisionLoaderBuilder {
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config specific to loading a vision model.
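///
/// All fields are optional; `Default` gives a plain load. For example, a
/// sketch that requests UQFF serialization after quantization (the path is a
/// placeholder):
///
/// ```ignore
/// let config = VisionSpecificConfig {
///     write_uqff: Some("model.uqff".into()),
///     ..Default::default()
/// };
/// ```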
pub struct VisionSpecificConfig {
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub max_edge: Option<u32>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
    pub matformer_config_path: Option<PathBuf>,
    pub matformer_slice_name: Option<String>,
}

impl VisionLoaderBuilder {
    pub fn new(
        config: VisionSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

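    /// Mark this as a LoRA adapter model and record the adapter IDs to apply
    /// at load time.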
    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

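    /// Construct the [`Loader`]. Passing `None` selects [`AutoVisionLoader`],
    /// which detects the model architecture from its config.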
    pub fn build(self, loader: Option<VisionLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn VisionModelLoader> = match loader {
            Some(VisionLoaderType::Phi3V) => Box::new(Phi3VLoader),
            Some(VisionLoaderType::Idefics2) => Box::new(Idefics2Loader),
            Some(VisionLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
            Some(VisionLoaderType::LLaVA) => Box::new(LLaVALoader),
            Some(VisionLoaderType::VLlama) => Box::new(VLlamaLoader),
            Some(VisionLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
            Some(VisionLoaderType::Idefics3) => Box::new(Idefics3Loader),
            Some(VisionLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
            Some(VisionLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
            Some(VisionLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
            Some(VisionLoaderType::Gemma3) => Box::new(Gemma3Loader),
            Some(VisionLoaderType::Mistral3) => Box::new(Mistral3Loader),
            Some(VisionLoaderType::Llama4) => Box::new(VLlama4Loader),
            Some(VisionLoaderType::Gemma3n) => Box::new(Gemma3nLoader),
            Some(VisionLoaderType::Qwen3VL) => Box::new(Qwen3VLLoader),
            None => Box::new(AutoVisionLoader),
        };
        Box::new(VisionLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            xlora_model_id: None,
            xlora_order: None,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for VisionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention(&config) {
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // Load the Matformer slicing config, if provided.
        let matformer_slicing_config = if let Some(matformer_path) =
            &self.config.matformer_config_path
        {
            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
            info!("Loading Matformer config from {:?}", matformer_path);
            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);

            if let Some(slice_name) = &self.config.matformer_slice_name {
                info!("Using Matformer slice: {}", slice_name);
                Some(MatformerSliceConfig::new(slice_name.clone(), config))
            } else {
                // A config was provided without a slice name; for now, return
                // None and let the model pick its default slice.
                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
                None
            }
        } else {
            None
        };

        // If the mapper is `Auto`, convert it to a concrete `Map` (unless NCCL is in use).
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
            // We can promote to vision params if we get text params
            params = params.maybe_promote_to_vision();

            // Initial dtype
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // Disable ISQ if we are loading a prequantized model.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // ISQ or UQFF: quantized path.
            // This matches the logic below, where UQFF takes priority.
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
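                        // Average the per-tensor pack factors to estimate the
                        // overall size reduction; a UQFF file may mix
                        // quantization schemes across tensors.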
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        // TODO: PagedAttention is not supported with CPU for now.
        // This check is not really necessary because `get_device_layers` should prevent it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu && paged_attn_config.is_some() {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        let topology_overrides = self
            .config
            .topology
            .as_ref()
            .map(|topology| {
                topology
                    .pattern_overrides()
                    .into_iter()
                    .map(|(regex, layer)| ImmediateIsqOverride {
                        predicate: regex,
                        ty: layer.isq,
                        device: layer.device.clone(),
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        let has_override_isq = topology_overrides
            .iter()
            .any(|override_entry| override_entry.ty.is_some());
        let topology_requires_post_quant = self
            .config
            .topology
            .as_ref()
            .is_some_and(|topology| topology.requires_post_quantization());

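        // "Immediate" ISQ quantizes tensors as they are loaded, rather than in
        // a separate post-load pass. It is only taken here when no
        // imatrix/calibration data must be collected first and the target
        // device is not CUDA.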
        let allow_immediate_cli = self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && in_situ_quant.is_some();

        let mut immediate_ty = None;
        let mut immediate_predicates = Vec::new();
        if allow_immediate_cli {
            immediate_ty = in_situ_quant;
            immediate_predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if immediate_predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
        }

        let use_immediate = allow_immediate_cli || has_override_isq;
        if use_immediate {
            mistralrs_quant::set_immediate_isq_with_overrides(
                immediate_ty,
                immediate_predicates.clone(),
                topology_overrides.clone(),
            );
        }

        // ISQ logic: if there is no calibration data (i.e., no imatrix), allow
        // immediate ISQ; otherwise, fall back to the normal post-load pass.
        let mut loading_isq = if use_immediate {
            false
        } else {
            in_situ_quant.is_some()
        };
        if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
            loading_isq = true;
        }
        loading_isq |= topology_requires_post_quant;

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // Load onto the regular device if not using ISQ, or if a calibration file is specified.
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(new_multi_progress());

        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            // Special case where things can be loaded more optimally.
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        };

        // Handle the Gemma 3 1b case here
        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        ); // Some repos don't properly handle the config position, for example... LLaVA

        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            // Tokenize; don't add BOS yet, it is prepended per chunk below.
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            // NOTE: We ONLY calibrate the text bits of these models!!
            // So only those should be tracked!
            model.begin_track_stats()?;

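            // Run the calibration corpus through the model in fixed-size
            // chunks, prepending BOS to each chunk and clearing the KV cache
            // between chunks.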
            const CHUNK_SIZE: usize = 1024;
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    None,
                )?;
                let _ = model.forward(
                    &inputs.input,
                    None, // NOTE: We ONLY calibrate the text bits of these models!!
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

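        // Post-load quantization and/or UQFF serialization. This is skipped
        // when loading *from* UQFF; in that case the pre-quantized artifacts
        // are loaded below instead.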
        let should_serialize = self.config.write_uqff.is_some();
        let should_quantize_pass = loading_isq;

        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
            let imatrix_source = if should_quantize_pass {
                match (
                    self.config.imatrix.as_ref(),
                    self.config.calibration_file.is_some(),
                ) {
                    (None, false) => None,
                    (Some(file), false) => Some(ImatrixDataSource::File(file)),
                    (None, true) => Some(ImatrixDataSource::Collected),
                    (Some(_), true) => unreachable!(),
                }
            } else {
                None
            };
            if should_quantize_pass {
                info!("Applying ISQ to all ranks.");
            } else {
                info!("Serializing existing ISQ tensors without additional quantization.");
            }
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                should_quantize_pass,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                    modules: None,
                    module_paths: None,
                },
                Arc::new(new_multi_progress()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

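        // If PagedAttention is enabled, size the KV cache from the configured
        // memory budget and construct the cache engine.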
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                model_metadata: Some(model_metadata),
                modalities: self.inner.modalities(&config)?,
            }),
            processor,
            prefixer: self.inner.prefixer(&config),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for VisionPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        Some(self.preprocessor_config.clone())
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        self.processor.clone()
    }
}

impl IsqPipelineMixin for VisionPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                true,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                    modules: None,
                    module_paths: None,
                },
                Arc::new(new_multi_progress()),
            )
            .map_err(anyhow::Error::msg)
    }
}

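/// Cache management dispatches to [`FullCacheManager`] or
/// [`NormalCacheManager`] depending on the cache layout the model uses.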
impl CacheManagerMixin for VisionPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for VisionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Vision {
            prefixer: self.prefixer.clone(),
        }
    }
}

impl AnyMoePipelineMixin for VisionPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        // Precompile the regex here.
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        // Index one past the last char of the layer id.
                        // Assumes an `N.mlp`-style key layout.
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}