mistralrs_core/pipeline/vision.rs

use super::isq::ImatrixDataSource;
use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoVisionLoader,
    CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
    GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
    ModelKind, ModelPaths, MultimodalPromptPrefixer, Phi4MMLoader, PreProcessingMixin, Processor,
    Qwen2VLLoader, Qwen3VLLoader, Qwen3VLMoELoader, TokenSource, VLlama4Loader, VLlamaLoader,
    VisionModel, VisionModelLoader,
};
use super::{
    Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader,
    Phi3VLoader, Qwen2_5VLLoader, VisionLoaderType,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::llg::build_llg_factory;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{
    progress::{new_multi_progress, ProgressScopeGuard},
    tokens::get_token,
    varbuilder_utils::from_mmaped_safetensors,
};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::ModelInputs;
use crate::{
    api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
    vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::{
    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    // For full UQFF serialization
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}

/// A loader for a vision (non-quantized) model.
pub struct VisionLoader {
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Default)]
/// A builder for a loader for a vision (non-quantized) model.
pub struct VisionLoaderBuilder {
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config specific to loading a vision model.
pub struct VisionSpecificConfig {
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub max_edge: Option<u32>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
    pub matformer_config_path: Option<PathBuf>,
    pub matformer_slice_name: Option<String>,
}

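// Illustrative usage sketch (not part of the original file); the model ID below is a placeholder:
//
//     let loader = VisionLoaderBuilder::new(
//         VisionSpecificConfig::default(),
//         None,       // chat_template
//         None,       // tokenizer_json
//         Some("some-org/some-vision-model".to_string()),
//         None,       // jinja_explicit
//     )
//     .build(None);   // `None` selects `AutoVisionLoader`, which auto-detects the architecture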
impl VisionLoaderBuilder {
    pub fn new(
        config: VisionSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn build(self, loader: Option<VisionLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn VisionModelLoader> = match loader {
            Some(VisionLoaderType::Phi3V) => Box::new(Phi3VLoader),
            Some(VisionLoaderType::Idefics2) => Box::new(Idefics2Loader),
            Some(VisionLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
            Some(VisionLoaderType::LLaVA) => Box::new(LLaVALoader),
            Some(VisionLoaderType::VLlama) => Box::new(VLlamaLoader),
            Some(VisionLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
            Some(VisionLoaderType::Idefics3) => Box::new(Idefics3Loader),
            Some(VisionLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
            Some(VisionLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
            Some(VisionLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
            Some(VisionLoaderType::Gemma3) => Box::new(Gemma3Loader),
            Some(VisionLoaderType::Mistral3) => Box::new(Mistral3Loader),
            Some(VisionLoaderType::Llama4) => Box::new(VLlama4Loader),
            Some(VisionLoaderType::Gemma3n) => Box::new(Gemma3nLoader),
            Some(VisionLoaderType::Qwen3VL) => Box::new(Qwen3VLLoader),
            Some(VisionLoaderType::Qwen3VLMoE) => Box::new(Qwen3VLMoELoader),
            None => Box::new(AutoVisionLoader),
        };
        Box::new(VisionLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            xlora_model_id: None,
            xlora_order: None,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for VisionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention(&config) {
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // Load matformer slicing config if provided
        let matformer_slicing_config = if let Some(matformer_path) =
            &self.config.matformer_config_path
        {
            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
            info!("Loading Matformer config from {:?}", matformer_path);
            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);

            if let Some(slice_name) = &self.config.matformer_slice_name {
                info!("Using Matformer slice: {}", slice_name);
                Some(MatformerSliceConfig::new(slice_name.clone(), config))
            } else {
                // If no slice name is provided but a config exists, return None for now and let
                // the model fall back to its default slice selection.
                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
                None
            }
        } else {
            None
        };

        // If auto, convert to Map if not using nccl
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
            // We can promote to vision params if we get text params
            params = params.maybe_promote_to_vision();

            // Initial dtype
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // Disable ISQ if we are loading a prequantized model.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // ISQ or UQFF: quantized path
            // Match logic below where UQFF has priority
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

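        // Two mappers are built from the same `DeviceMapSetting`: `pipeline_mapper` is retained by
        // the pipeline for later use, while `mapper` drives per-layer device placement during loading.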
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        // TODO: PagedAttention is not supported with CPU for now.
        // This check is not really necessary because `get_device_layers` should prevent it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu && paged_attn_config.is_some() {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        let topology_overrides = self
            .config
            .topology
            .as_ref()
            .map(|topology| {
                topology
                    .pattern_overrides()
                    .into_iter()
                    .map(|(regex, layer)| ImmediateIsqOverride {
                        predicate: regex,
                        ty: layer.isq,
                        device: layer.device.clone(),
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        let has_override_isq = topology_overrides
            .iter()
            .any(|override_entry| override_entry.ty.is_some());
        let topology_requires_post_quant = self
            .config
            .topology
            .as_ref()
            .is_some_and(|topology| topology.requires_post_quantization());

        let allow_immediate_cli = self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && in_situ_quant.is_some();

        let mut immediate_ty = None;
        let mut immediate_predicates = Vec::new();
        if allow_immediate_cli {
            immediate_ty = in_situ_quant;
            immediate_predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if immediate_predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
        }

        let use_immediate = allow_immediate_cli || has_override_isq;
        if use_immediate {
            mistralrs_quant::set_immediate_isq_with_overrides(
                immediate_ty,
                immediate_predicates.clone(),
                topology_overrides.clone(),
            );
        }

        // ISQ logic: if there is no calibration data (i.e., no imatrix), allow immediate ISQ.
        // Otherwise, fall back to the normal post-load quantization pass.
        let mut loading_isq = if use_immediate {
            false
        } else {
            in_situ_quant.is_some()
        };
        if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
            loading_isq = true;
        }
        loading_isq |= topology_requires_post_quant;

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // Load onto the regular device if not using ISQ or if a calibration file is specified
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

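        // Load the model weights: either through the NCCL-sharded path or the standard
        // (optionally device-mapped) safetensors path below.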
        let multi_progress = Arc::new(new_multi_progress());

        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            // Special case for where things can be more optimally loaded.
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        };

        // Handle the Gemma 3 1b case here
        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        ); // There are always some repos that don't properly handle config position, for example LLaVA.

        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

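        // Optional calibration pass: when a calibration file is supplied, run its text through the
        // model in fixed-size chunks so imatrix statistics can be collected before quantization.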
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            // Tokenize, don't add bos yet
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            // NOTE: We ONLY calibrate the text bits of these models!!
            // So only those should be tracked!
            model.begin_track_stats()?;

            const CHUNK_SIZE: usize = 1024;
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    None,
                )?;
                let _ = model.forward(
                    &inputs.input,
                    None, // NOTE: We ONLY calibrate the text bits of these models!!
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                    EitherCache::Hybrid(hybrid) => {
                        hybrid.lock().unwrap().reset();
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

        let should_serialize = self.config.write_uqff.is_some();
        let should_quantize_pass = loading_isq;

        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
            let imatrix_source = if should_quantize_pass {
                match (
                    self.config.imatrix.as_ref(),
                    self.config.calibration_file.is_some(),
                ) {
                    (None, false) => None,
                    (Some(file), false) => Some(ImatrixDataSource::File(file)),
                    (None, true) => Some(ImatrixDataSource::Collected),
                    (Some(_), true) => unreachable!(),
                }
            } else {
                None
            };
            if should_quantize_pass {
                info!("Applying ISQ to all ranks.");
            } else {
                info!("Serializing existing ISQ tensors without additional quantization.");
            }
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                should_quantize_pass,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                    modules: None,
                    module_paths: None,
                },
                Arc::new(new_multi_progress()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

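        // If PagedAttention is enabled, size the KV-cache blocks from the requested GPU memory and
        // build the cache engine; otherwise both stay unset and the eager attention path is used.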
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
            EitherCache::Hybrid(hybrid) => hybrid.lock().unwrap().num_layers(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                model_metadata: Some(model_metadata),
                modalities: self.inner.modalities(&config)?,
            }),
            processor,
            prefixer: self.inner.prefixer(&config),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for VisionPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        Some(self.preprocessor_config.clone())
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        self.processor.clone()
    }
}

impl IsqPipelineMixin for VisionPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                true,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                    modules: None,
                    module_paths: None,
                },
                Arc::new(new_multi_progress()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for VisionPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for VisionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Vision {
            prefixer: self.prefixer.clone(),
        }
    }
}

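// AnyMoE support: expert (and optional gating) safetensors are fetched from the Hub, filtered by
// the provided layer regex, and handed to the model to construct the MoE layers.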
impl AnyMoePipelineMixin for VisionPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        // Precompile regex here
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        // Idx of the last char of the layer id, +1
                        // Assumes N.MLP
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}