// mistralrs_core/pipeline/vision.rs

use super::cache_manager::{FullCacheManager, NormalCacheManager};
use super::isq::ImatrixDataSource;
use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, CacheManager,
    CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader, GeneralMetadata,
    IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory, ModelKind, ModelPaths,
    Phi4MMLoader, PreProcessingMixin, Processor, Qwen2VLLoader, TokenSource, VLlama4Loader,
    VLlamaLoader, VisionModel, VisionModelLoader, VisionPromptPrefixer,
};
use super::{
    Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader, Phi3VLoader,
    Qwen2_5VLLoader, VisionLoaderType,
};
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::llg::build_llg_factory;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::ModelInputs;
use crate::{
    api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
    vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn VisionPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    // For full UQFF serialization
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}

/// A loader for a vision (non-quantized) model.
pub struct VisionLoader {
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Default)]
/// A builder for a loader for a vision (non-quantized) model.
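///
/// A minimal usage sketch (the model ID and loader type below are illustrative
/// assumptions, not defaults):
/// ```ignore
/// let loader = VisionLoaderBuilder::new(
///     VisionSpecificConfig::default(),
///     None, // chat template
///     None, // tokenizer.json override
///     Some("microsoft/Phi-3.5-vision-instruct".to_string()),
///     None, // explicit JINJA chat template
/// )
/// .build(VisionLoaderType::Phi3V);
/// ```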
pub struct VisionLoaderBuilder {
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config specific to loading a vision model.
pub struct VisionSpecificConfig {
    pub use_flash_attn: bool,
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub max_edge: Option<u32>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
}

impl VisionLoaderBuilder {
    pub fn new(
        config: VisionSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn build(self, loader: VisionLoaderType) -> Box<dyn Loader> {
        let loader: Box<dyn VisionModelLoader> = match loader {
            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
            VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
        };
        Box::new(VisionLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            xlora_model_id: None,
            xlora_order: None,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for VisionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention() {
            paged_attn_config = None;
        }

        let use_nccl = mistralrs_quant::distributed::use_nccl();

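        // Select the device set: a distributed daemon worker binds the CUDA device at
        // `worker_rank + 1` (the leader holds device 0), the NCCL leader uses device 0,
        // and otherwise every device similar to the requested one is a mapping candidate.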
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // If auto, convert to Map if not using nccl
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            // Initial dtype
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // ISQ or UQFF: quantized path
            // Match logic below where UQFF has priority
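            // In the UQFF branch below, the effective weight compression is estimated by
            // averaging each serialized tensor's pack factor; the resulting per-layer
            // size estimates drive automatic device mapping.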
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let layer_sizes_in_bytes =
                        self.inner.layer_sizes_in_bytes(&config, dtype, 1)?;
                    let non_mapped_size_in_bytes =
                        self.inner.non_mapped_size_in_bytes(&config, dtype, 1)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

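            // With the size estimates in hand, ask the model loader to assign each layer
            // to a device, honoring the maximum sequence length and any PagedAttention
            // KV-cache overhead.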
            let new = self.inner.get_device_layers(
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                params.max_seq_len(),
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

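        // Two mappers are built from the same setting: one is stored on the pipeline for
        // later use, the other is consumed while loading the model below.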
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        // TODO: PagedAttention is not supported with CPU for now.
        // This check is not really necessary because `get_device_layers` should prevent it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!(
            "Model config: {:?}",
            self.inner
                .get_config_repr(&config, self.config.use_flash_attn)?
        );

        let mut loading_isq = in_situ_quant.is_some() || self.config.from_uqff.is_some();
        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // Load onto the regular device if not using ISQ or if a calibration file is specified
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            // Special case where things can be loaded more optimally.
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    self.config.use_flash_attn,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    self.config.use_flash_attn,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                ),
                _ => unreachable!(),
            }
        };

        // Handle the Gemma 3 1b case here
        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        // There are always some repos that don't properly handle the config position, for example LLaVA.
        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        );

        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
            serde_json::from_str(&fs::read_to_string(f).unwrap())
                .expect("bos_token_id/eos_token_id missing in generation_config.json")
        });
        let chat_template = get_chat_template(
            paths,
            &self.jinja_explicit,
            &paths
                .get_chat_template_explicit()
                .as_ref()
                .map(|x| x.to_string_lossy().to_string())
                .clone(),
            &self.chat_template,
            None,
        );

        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            // Tokenize; don't add the BOS token yet
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the BOS token is not present.");

            // NOTE: We ONLY calibrate the text bits of these models!!
            // So only those should be tracked!
            model.begin_track_stats()?;

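            // Feed the calibration data through the model in fixed-size chunks,
            // prepending the BOS token to each chunk.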
            const CHUNK_SIZE: usize = 1024;
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    None,
                )?;
                let _ = model.forward(
                    &inputs.input,
                    None, // NOTE: We ONLY calibrate the text bits of these models!!
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
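                // Clear the KV cache so the next chunk is processed as an
                // independent prompt.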
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

        if (in_situ_quant.is_some() || self.config.topology.is_some())
            && self.config.from_uqff.is_none()
        {
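            // The imatrix source is either an explicit file or the statistics collected
            // during the calibration pass above; specifying both was rejected earlier.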
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                (Some(_), true) => unreachable!(),
            };
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                },
                Arc::new(MultiProgress::new()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

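        // Allocate the PagedAttention KV cache, sizing the block pools from the
        // configured GPU/CPU memory budgets.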
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                dtype,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
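        // The cache holds one entry per hidden layer, so its length gives the layer count.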
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: false,
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                prompt_chunksize: self.config.prompt_chunksize,
                model_metadata: Some(model_metadata),
            }),
            processor,
            prefixer: self.inner.prefixer(),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for VisionPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        Some(self.preprocessor_config.clone())
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        self.processor.clone()
    }
}

impl IsqPipelineMixin for VisionPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                },
                Arc::new(MultiProgress::new()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for VisionPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for VisionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        let has_conv2d = self.model.has_conv2d();
        ModelCategory::Vision {
            has_conv2d,
            prefixer: self.prefixer.clone(),
        }
    }
}

impl AnyMoePipelineMixin for VisionPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        // Precompile regex here
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        // Idx of the last char of the layer id, +1
                        // Assumes N.MLP
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

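        // Optionally load the gating layers from a separate repo, which must contain
        // exactly one `.safetensors` file.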
        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}