mistralrs_core/pipeline/vision.rs

use super::isq::ImatrixDataSource;
use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoVisionLoader,
    CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
    GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
    ModelKind, ModelPaths, Phi4MMLoader, PreProcessingMixin, Processor, Qwen2VLLoader, TokenSource,
    VLlama4Loader, VLlamaLoader, VisionModel, VisionModelLoader, VisionPromptPrefixer,
};
use super::{
    Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader, Phi3VLoader,
    Qwen2_5VLLoader, VisionLoaderType,
};
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::llg::build_llg_factory;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::ModelInputs;
use crate::{
    api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
    vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

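/// A loaded vision pipeline: the model plus its tokenizer, chat template,
/// (pre)processors, device mapper, and general metadata.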
pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn VisionPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    // For full UQFF serialization
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}

/// A loader for a vision (non-quantized) model.
pub struct VisionLoader {
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Default)]
/// A builder for a loader for a vision (non-quantized) model.
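///
/// # Example
///
/// A minimal sketch of typical usage; the model id is illustrative and any
/// supported vision repo works:
///
/// ```ignore
/// let loader = VisionLoaderBuilder::new(
///     VisionSpecificConfig::default(),
///     None, // chat template override
///     None, // tokenizer.json override
///     Some("llava-hf/llava-1.5-7b-hf".to_string()),
///     None, // explicit JINJA template
/// )
/// .build(Some(VisionLoaderType::LLaVA));
/// ```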
pub struct VisionLoaderBuilder {
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config specific to loading a vision model.
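///
/// A minimal sketch of enabling UQFF loading (the file name is hypothetical):
///
/// ```ignore
/// let cfg = VisionSpecificConfig {
///     from_uqff: Some(vec!["model-q4k.uqff".into()]),
///     ..Default::default()
/// };
/// ```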
pub struct VisionSpecificConfig {
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub max_edge: Option<u32>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
}

impl VisionLoaderBuilder {
    pub fn new(
        config: VisionSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn build(self, loader: Option<VisionLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn VisionModelLoader> = match loader {
            Some(VisionLoaderType::Phi3V) => Box::new(Phi3VLoader),
            Some(VisionLoaderType::Idefics2) => Box::new(Idefics2Loader),
            Some(VisionLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
            Some(VisionLoaderType::LLaVA) => Box::new(LLaVALoader),
            Some(VisionLoaderType::VLlama) => Box::new(VLlamaLoader),
            Some(VisionLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
            Some(VisionLoaderType::Idefics3) => Box::new(Idefics3Loader),
            Some(VisionLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
            Some(VisionLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
            Some(VisionLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
            Some(VisionLoaderType::Gemma3) => Box::new(Gemma3Loader),
            Some(VisionLoaderType::Mistral3) => Box::new(Mistral3Loader),
            Some(VisionLoaderType::Llama4) => Box::new(VLlama4Loader),
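            // No explicit loader type: fall back to AutoVisionLoader, which selects
            // the concrete loader automatically.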
            None => Box::new(AutoVisionLoader),
        };
        Box::new(VisionLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            xlora_model_id: None,
            xlora_order: None,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for VisionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention(&config) {
            paged_attn_config = None;
        }

        let use_nccl = mistralrs_quant::distributed::use_nccl();

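        // Daemon workers in distributed runs receive their rank via an env payload and
        // bind their own CUDA stream; with NCCL the head process drives device 0.
        // Otherwise, consider every available device similar to the requested one.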
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // If auto, convert to Map if not using nccl
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            // Initial dtype
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // Disable ISQ if we are loading a prequantized model.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // ISQ or UQFF: quantized path
            // Match logic below where UQFF has priority
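            // The average pack factor computed below scales the per-layer size
            // estimates. Illustrative numbers only: two 4-bit-packed tensors (pack
            // factor 8 at F32) plus one unquantized tensor (pack factor 1) average
            // to (8 + 8 + 1) / 3 = 5 under the integer division below.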
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = self.inner.get_device_layers(
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                params.max_seq_len(),
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        // TODO: PagedAttention is not supported with CPU for now.
        // This check is not really necessary because `get_device_layers` should prevent it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        // ISQ logic: with no calibration (i.e., no imatrix) and no UQFF write, allow immediate ISQ; otherwise use the normal post-load ISQ path.
        let mut loading_isq = if self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && self.config.write_uqff.is_none()
            && in_situ_quant.is_some()
        {
            let predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
            false
        } else {
            in_situ_quant.is_some()
        };

        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // Load onto the regular device if not using ISQ, or if a calibration file is specified.
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            // Special case where things can be loaded more optimally.
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                ),
                _ => unreachable!(),
            }
        };

        // Handle the Gemma 3 1b case here
        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        ); // There are always some repos that don't properly handle the config position, for example LLaVA.

        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
            serde_json::from_str(&fs::read_to_string(f).unwrap())
                .expect("bos_token_id/eos_token_id missing in generation_config.json")
        });
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            // Tokenize, don't add bos yet
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            // NOTE: We ONLY calibrate the text bits of these models!!
            // So only those should be tracked!
            model.begin_track_stats()?;

            const CHUNK_SIZE: usize = 1024;
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
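            // e.g. a 4096-token calibration file yields 4096.div_ceil(1024) = 4 chunks,
            // each prefixed with the BOS token before its forward pass.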
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    None,
                )?;
                let _ = model.forward(
                    &inputs.input,
                    None, // NOTE: We ONLY calibrate the text bits of these models!!
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
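                // Drop the KV cache between chunks so each chunk is processed as an
                // independent prompt while the imatrix statistics keep accumulating.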
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

        // Apply ISQ (or topology ISQ) now, unless we are loading from UQFF; the UQFF path is handled in the branch below.
        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                (Some(_), true) => unreachable!(),
            };
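            // When writing UQFF, `UqffFullSer` carries the tokenizer, chat template,
            // generation config, and (pre)processor configs so that they can be
            // serialized alongside the quantized weights.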
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                },
                Arc::new(MultiProgress::new()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                dtype,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                prompt_chunksize: self.config.prompt_chunksize,
                model_metadata: Some(model_metadata),
            }),
            processor,
            prefixer: self.inner.prefixer(&config),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for VisionPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        Some(self.preprocessor_config.clone())
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        self.processor.clone()
    }
}

impl IsqPipelineMixin for VisionPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                },
                Arc::new(MultiProgress::new()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for VisionPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for VisionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Vision {
            prefixer: self.prefixer.clone(),
        }
    }
}

impl AnyMoePipelineMixin for VisionPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        // Precompile the regex once; it is reused for every model id below.
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        // Index one past the last char of the layer id; assumes keys of
                        // the form `...<layer_n>.<match_regex>...` (i.e. "N.MLP").
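                        // Illustrative (hypothetical key): with match_regex "mlp" and key
                        // "model.layers.3.mlp.down_proj.weight", last_layer_idx lands on
                        // the '.' before "mlp", first_layer_idx on the '.' before "3",
                        // and layer_n parses to 3.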
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}