mistralrs_core/pipeline/vision.rs

use super::cache_manager::{FullCacheManager, NormalCacheManager};
use super::isq::ImatrixDataSource;
use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, CacheManager,
    CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader, GeneralMetadata,
    IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory, ModelKind, ModelPaths,
    Phi4MMLoader, PreProcessingMixin, Processor, Qwen2VLLoader, TokenSource, VLlamaLoader,
    VisionModel, VisionModelLoader, VisionPromptPrefixer,
};
use super::{
    Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader, Phi3VLoader,
    Qwen2_5VLLoader, VisionLoaderType,
};
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::llg::build_tok_env;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::ModelInputs;
use crate::{
    api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
    vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn VisionPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    // For full UQFF serialization
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}

/// A loader for a vision (non-quantized) model.
pub struct VisionLoader {
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<PathBuf>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Default)]
/// A builder for a loader for a vision (non-quantized) model.
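///
/// A minimal usage sketch (illustrative only; the model ID and loader type
/// below are placeholders, not defaults):
///
/// ```ignore
/// let loader = VisionLoaderBuilder::new(
///     VisionSpecificConfig::default(),
///     None,                                      // chat template override
///     None,                                      // tokenizer.json override
///     Some("org/some-vision-model".to_string()), // hypothetical model ID
///     None,                                      // explicit Jinja template
/// )
/// .build(VisionLoaderType::Idefics3);
/// ```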
pub struct VisionLoaderBuilder {
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
/// Config specific to loading a vision model.
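///
/// Since this derives `Default`, callers typically override only the fields
/// they need (a sketch; the field value is illustrative):
///
/// ```ignore
/// let config = VisionSpecificConfig {
///     prompt_chunksize: std::num::NonZeroUsize::new(512),
///     ..Default::default()
/// };
/// ```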
pub struct VisionSpecificConfig {
    pub use_flash_attn: bool,
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<PathBuf>,
    pub max_edge: Option<u32>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
}

impl VisionLoaderBuilder {
    pub fn new(
        config: VisionSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn build(self, loader: VisionLoaderType) -> Box<dyn Loader> {
        let loader: Box<dyn VisionModelLoader> = match loader {
            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
        };
        Box::new(VisionLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            xlora_model_id: None,
            xlora_order: None,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for VisionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention() {
            paged_attn_config = None;
        }

        let use_nccl = mistralrs_quant::distributed::use_nccl();

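        // Device selection: daemon workers bind CUDA ordinal `worker_rank + 1`
        // (ordinal 0 presumably stays with the primary process), plain NCCL uses
        // ordinal 0, and otherwise all devices similar to `device` are candidates.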
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // If auto, convert to a concrete Map; with NCCL we use a dummy mapper instead
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            // Initial dtype
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // ISQ or UQFF: quantized path
            // Match logic below where UQFF has priority
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::new(serialized)?
                        };
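                        // Average the per-tensor pack factors to estimate the file's
                        // overall compression. Illustrative arithmetic only: with two
                        // tensors, one packing 4 weights per element (factor 4) and one
                        // unquantized (factor 1), the integer average is (4 + 1) / 2 = 2.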
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let layer_sizes_in_bytes =
                        self.inner.layer_sizes_in_bytes(&config, dtype, 1)?;
                    let non_mapped_size_in_bytes =
                        self.inner.non_mapped_size_in_bytes(&config, dtype, 1)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = self.inner.get_device_layers(
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                params.max_seq_len(),
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

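        // Two mappers are materialized from the same setting: `pipeline_mapper` is
        // retained on the pipeline for later device queries, while `mapper` is
        // consumed below while loading the model itself.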
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        // TODO: PagedAttention is not supported with CPU for now.
        // This check is not really necessary because `get_device_layers` should prevent it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!(
            "Model config: {:?}",
            self.inner
                .get_config_repr(&config, self.config.use_flash_attn)?
        );

        let mut loading_isq = in_situ_quant.is_some() || self.config.from_uqff.is_some();
        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // Load onto the regular device if not using isq or if the calibration file is specified
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &load_device,
                &available_devices,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            // Special case where things can be loaded more optimally.
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    self.config.use_flash_attn,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    self.config.use_flash_attn,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                ),
                _ => unreachable!(),
            }
        };

        // Handle the Gemma 3 1b case here
        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        ); // There are always some repos that don't properly handle the config position, for example LLaVA.

        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
            serde_json::from_str(&fs::read_to_string(f).unwrap())
                .expect("bos_token_id/eos_token_id missing in generation_config.json")
        });
        let chat_template = get_chat_template(
            paths,
            &self.jinja_explicit,
            &paths
                .get_chat_template_explicit()
                .as_ref()
                .map(|x| x.to_string_lossy().to_string())
                .clone(),
            &self.chat_template,
            None,
        );

        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            // Tokenize, don't add bos yet
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            // NOTE: We ONLY calibrate the text bits of these models!!
            // So only those should be tracked!
            model.begin_track_stats()?;

            const CHUNK_SIZE: usize = 1024;
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
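            // Illustrative arithmetic: a 2500-token calibration file yields
            // ceil(2500 / 1024) = 3 chunks; each chunk is prefixed with the BOS
            // token before the forward pass below.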
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs =
                    make_prompt_chunk(0, vec![chunk], &[0], &load_device, None, false, None, None)?;
                let _ = model.forward(
                    &inputs.input,
                    None, // NOTE: We ONLY calibrate the text bits of these models!!
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

        if (in_situ_quant.is_some() || self.config.topology.is_some())
            && self.config.from_uqff.is_none()
        {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                (Some(_), true) => unreachable!(),
            };
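            // `(Some(_), true)` was rejected earlier with a bail, so an imatrix file
            // and collected calibration stats are mutually exclusive at this point.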
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                },
                Arc::new(MultiProgress::new()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                dtype,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let tok_env = build_tok_env(tokenizer.clone());
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
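        // The hidden layer count is inferred from the KV cache length rather than
        // re-parsed from the model config.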
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                tok_env: Some(tok_env),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: false,
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                prompt_chunksize: self.config.prompt_chunksize,
                model_metadata: Some(model_metadata),
            }),
            processor,
            prefixer: self.inner.prefixer(),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for VisionPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        Some(self.preprocessor_config.clone())
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        self.processor.clone()
    }
}

impl IsqPipelineMixin for VisionPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                },
                Arc::new(MultiProgress::new()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for VisionPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for VisionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
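        // On Metal, the forward pass runs inside an autorelease pool so that
        // temporary Objective-C buffer allocations are reclaimed promptly.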
        #[cfg(feature = "metal")]
        let logits = objc::rc::autoreleasepool(|| {
            self.model.forward(
                &input_ids,
                pixel_values,
                &seqlen_offsets,
                context_lens,
                position_ids,
                model_specific_args,
                paged_attn_meta,
                &flash_meta,
            )
        })?;
        #[cfg(not(feature = "metal"))]
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        let has_conv2d = self.model.has_conv2d();
        ModelCategory::Vision {
            has_conv2d,
            prefixer: self.prefixer.clone(),
        }
    }
}

impl AnyMoePipelineMixin for VisionPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        // Precompile regex here
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        // Index one past the last character of the layer id.
                        // Assumes keys of the form `...<N>.mlp...`.
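                        // e.g. for a (hypothetical) key `model.layers.3.mlp.up_proj.weight`
                        // with match_regex `mlp`, this extracts layer_n = 3.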
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}