// mistralrs_core/pipeline/vision.rs

use super::isq::ImatrixDataSource;
use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoVisionLoader,
    CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
    GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
    ModelKind, ModelPaths, MultimodalPromptPrefixer, Phi4MMLoader, PreProcessingMixin, Processor,
    Qwen2VLLoader, TokenSource, VLlama4Loader, VLlamaLoader, VisionModel, VisionModelLoader,
};
use super::{
    Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader,
    Phi3VLoader, Qwen2_5VLLoader, VisionLoaderType,
};
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::llg::build_llg_factory;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::ModelInputs;
use crate::{
    api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
    vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

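/// A fully loaded vision pipeline: the model itself together with the
/// tokenizer, chat template, processor, and device mapper needed to serve
/// multimodal requests.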
pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    // For full UQFF serialization
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}

/// A loader for a vision (non-quantized) model.
pub struct VisionLoader {
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Default)]
/// A builder for a loader for a vision (non-quantized) model.
pub struct VisionLoaderBuilder {
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config specific to loading a vision model.
pub struct VisionSpecificConfig {
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub max_edge: Option<u32>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
    pub matformer_config_path: Option<PathBuf>,
    pub matformer_slice_name: Option<String>,
}

impl VisionLoaderBuilder {
    pub fn new(
        config: VisionSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

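    /// Construct the [`Loader`]. Passing `None` selects [`AutoVisionLoader`],
    /// which picks the concrete vision loader automatically.
    ///
    /// A minimal usage sketch (the model ID below is an illustrative
    /// assumption, not taken from this file):
    ///
    /// ```ignore
    /// let loader = VisionLoaderBuilder::new(
    ///     VisionSpecificConfig::default(),
    ///     None, // chat template
    ///     None, // tokenizer.json override
    ///     Some("llava-hf/llava-1.5-7b-hf".to_string()), // assumed model ID
    ///     None, // explicit JINJA template
    /// )
    /// .build(Some(VisionLoaderType::LLaVA));
    /// ```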
    pub fn build(self, loader: Option<VisionLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn VisionModelLoader> = match loader {
            Some(VisionLoaderType::Phi3V) => Box::new(Phi3VLoader),
            Some(VisionLoaderType::Idefics2) => Box::new(Idefics2Loader),
            Some(VisionLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
            Some(VisionLoaderType::LLaVA) => Box::new(LLaVALoader),
            Some(VisionLoaderType::VLlama) => Box::new(VLlamaLoader),
            Some(VisionLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
            Some(VisionLoaderType::Idefics3) => Box::new(Idefics3Loader),
            Some(VisionLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
            Some(VisionLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
            Some(VisionLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
            Some(VisionLoaderType::Gemma3) => Box::new(Gemma3Loader),
            Some(VisionLoaderType::Mistral3) => Box::new(Mistral3Loader),
            Some(VisionLoaderType::Llama4) => Box::new(VLlama4Loader),
            Some(VisionLoaderType::Gemma3n) => Box::new(Gemma3nLoader),
            None => Box::new(AutoVisionLoader),
        };
        Box::new(VisionLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            xlora_model_id: None,
            xlora_order: None,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

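// Loading happens in two stages: `load_model_from_hf` resolves (and, if
// needed, downloads) every model artifact, then delegates to
// `load_model_from_path` to build the pipeline from the local files.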
impl Loader for VisionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention(&config) {
            paged_attn_config = None;
        }

        let use_nccl = mistralrs_quant::distributed::use_nccl();

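        // Pick the candidate devices: a distributed daemon worker pins to CUDA
        // device `worker_rank + 1`, an NCCL run uses device 0, and otherwise
        // all devices similar to the requested one are gathered for mapping.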
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // Load matformer slicing config if provided
        let matformer_slicing_config = if let Some(matformer_path) =
            &self.config.matformer_config_path
        {
            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
            info!("Loading Matformer config from {:?}", matformer_path);
            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);

            if let Some(slice_name) = &self.config.matformer_slice_name {
                info!("Using Matformer slice: {}", slice_name);
                Some(MatformerSliceConfig::new(slice_name.clone(), config))
            } else {
                // If no slice name is provided but config exists, we'll need to handle this
                // For now, return None and let the model handle the default slice selection
                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
                None
            }
        } else {
            None
        };

        // If auto, convert to Map if not using nccl
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
            // We can promote to vision params if we get text params
            params = params.maybe_promote_to_vision();

            // Initial dtype
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // Disable ISQ if we are loading a prequantized model.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // ISQ or UQFF: quantized path
            // Match logic below where UQFF has priority
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
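                    // Estimate the effective weight pack factor of the UQFF
                    // artifacts by averaging each serialized tensor's pack
                    // factor.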
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                params.max_seq_len(),
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

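        // Two mappers are built from the same setting: `pipeline_mapper` is
        // kept on the pipeline for later use, while `mapper` is consumed
        // during weight loading below.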
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        // TODO: PagedAttention is not supported with CPU for now.
        // This check is not really necessary because `get_device_layers` should prevent it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu && paged_attn_config.is_some() {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        // ISQ logic: if there is no calibration (i.e., imatrix) step, allow immediate ISQ. Otherwise, fall back to the normal path.
        let mut loading_isq = if self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && self.config.write_uqff.is_none()
            && in_situ_quant.is_some()
        {
            let predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
            false
        } else {
            in_situ_quant.is_some()
        };

        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // Load onto the regular device if not using isq or if the calibration file is specified
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            // Special case where things can be loaded more optimally.
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        };

        // Handle the Gemma 3 1b case here: text-only checkpoints ship no preprocessor config, so fall back to the default.
        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        ); // There are always some repos that don't properly handle config position, for example... LLaVA

        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

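        // Imatrix calibration: tokenize the calibration file and run it
        // through the model in fixed-size chunks (text only), tracking
        // activation statistics and resetting the KV cache between chunks.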
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            // Tokenize, don't add bos yet
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            // NOTE: We ONLY calibrate the text bits of these models!!
            // So only those should be tracked!
            model.begin_track_stats()?;

            const CHUNK_SIZE: usize = 1024;
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    None,
                )?;
                let _ = model.forward(
                    &inputs.input,
                    None, // NOTE: We ONLY calibrate the text bits of these models!!
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

        // Quantize (ISQ or topology) here unless we are loading from UQFF; in that case, load the serialized artifacts below instead.
        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                (Some(_), true) => unreachable!(),
            };
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                },
                Arc::new(MultiProgress::new()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

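        // With PagedAttention enabled, size the KV cache from the configured
        // GPU/CPU memory budgets and build the cache engine; otherwise run
        // with the eager cache path.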
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                prompt_chunksize: self.config.prompt_chunksize,
                model_metadata: Some(model_metadata),
                modalities: self.inner.modalities(&config)?,
            }),
            processor,
            prefixer: self.inner.prefixer(&config),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for VisionPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        Some(self.preprocessor_config.clone())
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        self.processor.clone()
    }
}

impl IsqPipelineMixin for VisionPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                },
                Arc::new(MultiProgress::new()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for VisionPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for VisionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
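        // The PagedAttention metadata and the cache engine must agree: both
        // present (paged path) or both absent (eager path); anything else is
        // an error.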
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Vision {
            prefixer: self.prefixer.clone(),
        }
    }
}

impl AnyMoePipelineMixin for VisionPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        // Precompile regex here
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
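        // For each expert model ID, fetch its safetensors from the Hub and
        // load only tensors whose keys match the MLP regex for the requested
        // layers (an empty layer list means all layers).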
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        // Index one past the last char of the layer id.
                        // Assumes keys of the form `<prefix>.<N>.<mlp>`.
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}