mistralrs_core/pipeline/vision.rs

use super::isq::ImatrixDataSource;
use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoVisionLoader,
    CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
    GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
    ModelKind, ModelPaths, MultimodalPromptPrefixer, Phi4MMLoader, PreProcessingMixin, Processor,
    Qwen2VLLoader, TokenSource, VLlama4Loader, VLlamaLoader, VisionModel, VisionModelLoader,
};
use super::{
    Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader,
    Phi3VLoader, Qwen2_5VLLoader, VisionLoaderType,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::llg::build_llg_factory;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::ModelInputs;
use crate::{
    api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
    vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

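/// A loaded vision pipeline: the model plus its tokenizer, chat template,
/// processors, and device mapper, along with the file paths retained for full
/// UQFF serialization.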
pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    // For full UQFF serialization
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}

/// A loader for a vision (non-quantized) model.
pub struct VisionLoader {
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Default)]
/// A builder for a loader for a vision (non-quantized) model.
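///
/// # Example
///
/// A minimal usage sketch; the model ID is a placeholder and imports are elided:
///
/// ```ignore
/// let loader = VisionLoaderBuilder::new(
///     VisionSpecificConfig::default(),
///     None,                                 // chat template
///     None,                                 // tokenizer.json override
///     Some("org/vision-model".to_string()), // model ID (placeholder)
///     None,                                 // explicit Jinja template
/// )
/// .build(Some(VisionLoaderType::LLaVA));
/// ```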
pub struct VisionLoaderBuilder {
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config specific to loading a vision model.
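///
/// `imatrix` and `calibration_file` are mutually exclusive; specifying both
/// causes loading to fail. A minimal sketch of a UQFF-loading config (the file
/// name is a placeholder):
///
/// ```ignore
/// let cfg = VisionSpecificConfig {
///     from_uqff: Some(vec!["model.uqff".into()]),
///     ..Default::default()
/// };
/// ```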
pub struct VisionSpecificConfig {
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub max_edge: Option<u32>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
    pub matformer_config_path: Option<PathBuf>,
    pub matformer_slice_name: Option<String>,
}

impl VisionLoaderBuilder {
    pub fn new(
        config: VisionSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

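    /// Override the Hugging Face cache directory used when fetching model files.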
    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

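    /// Load the given LoRA adapters on top of the base model, switching the
    /// model kind to an adapter model.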
    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn build(self, loader: Option<VisionLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn VisionModelLoader> = match loader {
            Some(VisionLoaderType::Phi3V) => Box::new(Phi3VLoader),
            Some(VisionLoaderType::Idefics2) => Box::new(Idefics2Loader),
            Some(VisionLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
            Some(VisionLoaderType::LLaVA) => Box::new(LLaVALoader),
            Some(VisionLoaderType::VLlama) => Box::new(VLlamaLoader),
            Some(VisionLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
            Some(VisionLoaderType::Idefics3) => Box::new(Idefics3Loader),
            Some(VisionLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
            Some(VisionLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
            Some(VisionLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
            Some(VisionLoaderType::Gemma3) => Box::new(Gemma3Loader),
            Some(VisionLoaderType::Mistral3) => Box::new(Mistral3Loader),
            Some(VisionLoaderType::Llama4) => Box::new(VLlama4Loader),
            Some(VisionLoaderType::Gemma3n) => Box::new(Gemma3nLoader),
            None => Box::new(AutoVisionLoader),
        };
        Box::new(VisionLoader {
            inner: loader,
            model_id: self.model_id.expect("A model ID is required to build a vision loader."),
            config: self.config,
            kind: self.kind,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            xlora_model_id: None,
            xlora_order: None,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for VisionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention(&config) {
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // Load the Matformer slicing config if provided
        let matformer_slicing_config = if let Some(matformer_path) =
            &self.config.matformer_config_path
        {
            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
            info!("Loading Matformer config from {:?}", matformer_path);
            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);

            if let Some(slice_name) = &self.config.matformer_slice_name {
                info!("Using Matformer slice: {}", slice_name);
                Some(MatformerSliceConfig::new(slice_name.clone(), config))
            } else {
                // A config without a slice name is allowed; fall back to the model's
                // default slice selection.
                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
                None
            }
        } else {
            None
        };

        // If auto, convert to Map if not using nccl
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
            // We can promote to vision params if we get text params
            params = params.maybe_promote_to_vision();

            // Initial dtype
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // Disable ISQ if we are loading a prequantized model.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // ISQ or UQFF: quantized path
            // Match logic below where UQFF has priority
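            // The average pack factor across all serialized tensors gives the auto
            // device mapper a per-layer size estimate for the UQFF model.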
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

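            // Given the per-layer and non-mapped size estimates, compute a concrete
            // layer-to-device assignment across the available devices.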
            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

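        // Two mappers are built from the same setting: `pipeline_mapper` is stored on
        // the pipeline (exposed via `device_mapper`), while `mapper` is consumed
        // during loading.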
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        // TODO: PagedAttention is not supported with CPU for now.
        // This check is not really necessary because `get_device_layers` should prevent it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu && paged_attn_config.is_some() {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        // ISQ logic: if there is no calibration source (i.e., no imatrix or calibration
        // file) and no UQFF write, quantize immediately while loading; otherwise, take
        // the normal post-load quantization path.
        let mut loading_isq = if self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && self.config.write_uqff.is_none()
            && in_situ_quant.is_some()
        {
            let predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
            false
        } else {
            in_situ_quant.is_some()
        };

        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // Load onto the regular device if not using ISQ or if a calibration file is specified
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

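        // With NCCL, load through a sharded varbuilder for tensor parallelism;
        // otherwise load normally (onto the CPU first when applying ISQ).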
        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            // Special case for where things can be more optimally loaded.
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        };

        // Handle the Gemma 3 1b case here
        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        // There are always some repos that don't properly handle the config position,
        // for example... LLaVA.
        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        );

        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            // Tokenize; don't add the BOS token yet
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_tok_id = chat_template
                .bos_tok()
                .and_then(|b| tokenizer.token_to_id(&b))
                .expect("Somehow the bos token is not present.");

            // NOTE: We ONLY calibrate the text bits of these models!!
            // So only those should be tracked!
            model.begin_track_stats()?;

            const CHUNK_SIZE: usize = 1024;
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    None,
                )?;
                let _ = model.forward(
                    &inputs.input,
                    None, // NOTE: We ONLY calibrate the text bits of these models!!
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
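                // Reset the KV cache between chunks so each chunk is processed as an
                // independent prompt.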
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

        // Quantize now (ISQ or topology) unless we are loading from UQFF; in that
        // case, load the serialized artifacts instead.
        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                (Some(_), true) => unreachable!(),
            };
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                },
                Arc::new(MultiProgress::new()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

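        // Allocate the PagedAttention KV cache (GPU and CPU blocks) if enabled.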
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                model_metadata: Some(model_metadata),
                modalities: self.inner.modalities(&config)?,
            }),
            processor,
            prefixer: self.inner.prefixer(&config),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for VisionPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        Some(self.preprocessor_config.clone())
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        self.processor.clone()
    }
}

impl IsqPipelineMixin for VisionPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                },
                Arc::new(MultiProgress::new()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for VisionPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for VisionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Vision {
            prefixer: self.prefixer.clone(),
        }
    }
}

impl AnyMoePipelineMixin for VisionPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        // Precompile the regex here so it is not rebuilt per model ID
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        // Index of the last char of the layer id, plus one.
                        // Assumes keys of the form `<prefix>.N.<mlp>...`.
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}