1use super::isq::ImatrixDataSource;
2use super::isq::UqffFullSer;
3use super::{
4 get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoVisionLoader,
5 CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
6 GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
7 ModelKind, ModelPaths, MultimodalPromptPrefixer, Phi4MMLoader, PreProcessingMixin, Processor,
8 Qwen2VLLoader, TokenSource, VLlama4Loader, VLlamaLoader, VisionModel, VisionModelLoader,
9};
10use super::{
11 Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader,
12 Phi3VLoader, Qwen2_5VLLoader, VisionLoaderType,
13};
14use crate::device_map::{self, DeviceMapper};
15use crate::distributed::{self, WorkerTransferData};
16use crate::kv_cache::{FullCacheManager, NormalCacheManager};
17use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
18use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
19use crate::pipeline::llg::build_llg_factory;
20use crate::pipeline::loaders::auto_device_map;
21use crate::pipeline::loaders::QuantizationConfigShim;
22use crate::pipeline::sampling::sample_and_add_toks;
23use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
24use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
25use crate::prefix_cacher::PrefixCacheManagerV2;
26use crate::sequence::Sequence;
27use crate::utils::tokenizer::get_tokenizer;
28use crate::utils::varbuilder_utils::DeviceForLoadTensor;
29use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
30use crate::vision_models::preprocessor_config::PreProcessorConfig;
31use crate::vision_models::processor_config::ProcessorConfig;
32use crate::vision_models::ModelInputs;
33use crate::{
34 api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
35 vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
36 PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
37};
38use anyhow::Result;
39use candle_core::{Device, Tensor, Var};
40use hf_hub::Cache;
41use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
42use indicatif::MultiProgress;
43use mistralrs_quant::log::once_log_info;
44use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
45use rand_isaac::Isaac64Rng;
46use regex_automata::meta::Regex;
47use std::any::Any;
48use std::borrow::Cow;
49use std::num::NonZeroUsize;
50use std::path::{Path, PathBuf};
51use std::str::FromStr;
52use std::sync::{Arc, RwLock};
53use std::time::Instant;
54use std::{env, fs};
55use tokenizers::Tokenizer;
56use tokio::sync::Mutex;
57use tracing::{info, warn};
58
/// A loaded vision (multimodal) pipeline, constructed by [`VisionLoader`].
///
/// Holds the model together with its tokenizer, chat template, processors and device
/// mapper. The last group of fields is retained so that `re_isq_model` can re-quantize the
/// model after loading.
pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    /// Model identifier; also returned by `MetadataMixin::name`.
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    /// Per-layer topology, reused when ISQ is re-applied.
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    /// Device mapper exposed through `MetadataMixin::device_mapper`.
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    // The fields below are handed back to `quantize` (via `UqffFullSer`) by `re_isq_model`.
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    /// Raw contents of the model's config file.
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    /// imatrix file used as the `ImatrixDataSource` when re-quantizing.
    imatrix: Option<PathBuf>,
}
80
/// [`Loader`] implementation for vision models, produced by [`VisionLoaderBuilder::build`].
pub struct VisionLoader {
    /// Architecture-specific loader; `AutoVisionLoader` when `build` received no loader type.
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    /// `ModelKind::Normal`, or a LoRA adapter kind when built via `with_lora`.
    /// Only `ModelKind::Normal` can be loaded by `load_model_from_path`.
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    // Always `None` when constructed through `VisionLoaderBuilder::build`.
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    // The next three fields are written by `load_model_from_hf`, which only has `&self`,
    // hence the locks.
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    /// Directory used to build the cache passed to `GLOBAL_HF_CACHE.get_or_init`.
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}
98
/// Builder for [`VisionLoader`]; create it with [`VisionLoaderBuilder::new`] and finish with
/// [`VisionLoaderBuilder::build`].
#[derive(Default)]
pub struct VisionLoaderBuilder {
    /// Must be `Some` when `build` is called.
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    /// Set by `with_lora`.
    lora_adapter_ids: Option<Vec<String>>,
}
111
/// Vision-model specific loading options.
#[derive(Clone, Default)]
pub struct VisionSpecificConfig {
    /// Stored in the pipeline metadata as `prompt_chunksize`.
    pub prompt_chunksize: Option<NonZeroUsize>,
    /// Optional per-layer topology; used for device mapping and per-layer ISQ.
    pub topology: Option<Topology>,
    /// If set, quantization writes a UQFF artifact to this path.
    pub write_uqff: Option<PathBuf>,
    /// Pre-quantized UQFF artifacts to load instead of applying ISQ.
    pub from_uqff: Option<Vec<PathBuf>>,
    /// Passed to the model's processor as its maximum image edge.
    pub max_edge: Option<u32>,
    /// imatrix file for ISQ; mutually exclusive with `calibration_file`.
    pub imatrix: Option<PathBuf>,
    /// Text file from which an imatrix is collected before ISQ; mutually exclusive with
    /// `imatrix`.
    pub calibration_file: Option<PathBuf>,
    // NOTE(review): `VisionLoader::load_model_from_hf` initializes the HF cache from the
    // loader's own `hf_cache_path`, not from this field.
    pub hf_cache_path: Option<PathBuf>,
    /// Optional Matformer config file.
    pub matformer_config_path: Option<PathBuf>,
    /// Matformer slice to use; ignored (with a warning) if no config path is given.
    pub matformer_slice_name: Option<String>,
}
126
127impl VisionLoaderBuilder {
128 pub fn new(
129 config: VisionSpecificConfig,
130 chat_template: Option<String>,
131 tokenizer_json: Option<String>,
132 model_id: Option<String>,
133 jinja_explicit: Option<String>,
134 ) -> Self {
135 Self {
136 config,
137 chat_template,
138 tokenizer_json,
139 model_id,
140 jinja_explicit,
141 kind: ModelKind::Normal,
142 hf_cache_path: None,
143 ..Default::default()
144 }
145 }
146
147 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
148 self.hf_cache_path = Some(hf_cache_path);
149 self
150 }
151
152 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
153 self.kind = ModelKind::Adapter {
154 adapter: AdapterKind::Lora,
155 };
156 self.lora_adapter_ids = Some(lora_adapter_ids);
157 self
158 }
159
160 pub fn build(self, loader: Option<VisionLoaderType>) -> Box<dyn Loader> {
161 let loader: Box<dyn VisionModelLoader> = match loader {
162 Some(VisionLoaderType::Phi3V) => Box::new(Phi3VLoader),
163 Some(VisionLoaderType::Idefics2) => Box::new(Idefics2Loader),
164 Some(VisionLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
165 Some(VisionLoaderType::LLaVA) => Box::new(LLaVALoader),
166 Some(VisionLoaderType::VLlama) => Box::new(VLlamaLoader),
167 Some(VisionLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
168 Some(VisionLoaderType::Idefics3) => Box::new(Idefics3Loader),
169 Some(VisionLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
170 Some(VisionLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
171 Some(VisionLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
172 Some(VisionLoaderType::Gemma3) => Box::new(Gemma3Loader),
173 Some(VisionLoaderType::Mistral3) => Box::new(Mistral3Loader),
174 Some(VisionLoaderType::Llama4) => Box::new(VLlama4Loader),
175 Some(VisionLoaderType::Gemma3n) => Box::new(Gemma3nLoader),
176 None => Box::new(AutoVisionLoader),
177 };
178 Box::new(VisionLoader {
179 inner: loader,
180 model_id: self.model_id.unwrap(),
181 config: self.config,
182 kind: self.kind,
183 chat_template: self.chat_template,
184 tokenizer_json: self.tokenizer_json,
185 xlora_model_id: None,
186 xlora_order: None,
187 jinja_explicit: self.jinja_explicit,
188 token_source: RwLock::new(None),
189 revision: RwLock::new(None),
190 from_uqff: RwLock::new(None),
191 hf_cache_path: self.hf_cache_path,
192 lora_adapter_ids: self.lora_adapter_ids,
193 })
194 }
195}
196
197impl Loader for VisionLoader {
198 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
199 fn load_model_from_hf(
200 &self,
201 revision: Option<String>,
202 token_source: TokenSource,
203 dtype: &dyn TryIntoDType,
204 device: &Device,
205 silent: bool,
206 mapper: DeviceMapSetting,
207 in_situ_quant: Option<IsqType>,
208 paged_attn_config: Option<PagedAttentionConfig>,
209 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
210 let cache = self
211 .hf_cache_path
212 .clone()
213 .map(Cache::new)
214 .unwrap_or_default();
215 GLOBAL_HF_CACHE.get_or_init(|| cache);
216
217 let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
218 LocalModelPaths,
219 &token_source,
220 revision.clone(),
221 self,
222 None,
223 None,
224 silent,
225 self.config.from_uqff.is_some()
226 );
227 if let Some(from_uqff) = self.config.from_uqff.clone() {
228 *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
229 }
230 *self
231 .token_source
232 .write()
233 .expect("Failed to write to token source") = Some(token_source);
234 *self.revision.write().expect("Failed to write to revision") = revision;
235 self.load_model_from_path(
236 &paths?,
237 dtype,
238 device,
239 silent,
240 mapper,
241 in_situ_quant,
242 paged_attn_config,
243 )
244 }
245
246 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
247 fn load_model_from_path(
248 &self,
249 paths: &Box<dyn ModelPaths>,
250 dtype: &dyn TryIntoDType,
251 device: &Device,
252 silent: bool,
253 mut mapper: DeviceMapSetting,
254 mut in_situ_quant: Option<IsqType>,
255 mut paged_attn_config: Option<PagedAttentionConfig>,
256 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
257 let config = std::fs::read_to_string(paths.get_config_filename())?;
258
259 if !self.inner.supports_paged_attention(&config) {
260 paged_attn_config = None;
261 }
262
263 let use_nccl = mistralrs_quant::distributed::use_nccl();
264
265 let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
266 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
267 let WorkerTransferData::Init { id: _, worker_rank } = payload;
268 vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
269 } else if use_nccl {
270 vec![candle_core::Device::new_cuda_with_stream(0)?]
271 } else {
272 device_map::get_all_similar_devices(device)?
273 };
274 let device = if use_nccl {
275 available_devices[0].clone()
276 } else {
277 device.clone()
278 };
279
280 let matformer_slicing_config = if let Some(matformer_path) =
282 &self.config.matformer_config_path
283 {
284 use crate::matformer::{MatformerConfig, MatformerSliceConfig};
285 info!("Loading Matformer config from {:?}", matformer_path);
286 let config = Arc::new(MatformerConfig::from_file(matformer_path)?);
287
288 if let Some(slice_name) = &self.config.matformer_slice_name {
289 info!("Using Matformer slice: {}", slice_name);
290 Some(MatformerSliceConfig::new(slice_name.clone(), config))
291 } else {
292 warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
295 None
296 }
297 } else {
298 None
299 };
300
301 if use_nccl {
303 mapper = DeviceMapSetting::DummyNccl {
304 nm_device: available_devices[0].clone(),
305 };
306 } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
307 params = params.maybe_promote_to_vision();
309
310 let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
312
313 if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
315 in_situ_quant = None;
316 }
317
318 let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
321 if let Some(serialized) = &*self.from_uqff.read().unwrap() {
322 let weight_pack_factor = {
323 let ser_artifacts = unsafe {
324 candle_core::safetensors::MmapedSafetensors::multi(serialized)?
325 };
326 let mut total_pack_factors = 0;
327 let total_tensors = ser_artifacts.tensors().len();
328 for (_, artifact) in ser_artifacts.tensors() {
329 let artifact = artifact.data();
330 let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
332 let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
333 {
334 QuantizedSerdeType::Hqq => {
335 HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
336 .pack_factor(dtype)
337 }
338 QuantizedSerdeType::Gguf => {
339 GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
340 .pack_factor(dtype)
341 }
342 QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
343 QuantizedSerdeType::Unquant => 1,
344 QuantizedSerdeType::Afq => {
345 AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
346 .pack_factor(dtype)
347 }
348 };
349 total_pack_factors += pack_factor;
350 }
351
352 total_pack_factors / total_tensors
353 };
354
355 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
356 &config,
357 dtype,
358 weight_pack_factor,
359 matformer_slicing_config.as_ref(),
360 )?;
361 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
362 &config,
363 dtype,
364 weight_pack_factor,
365 matformer_slicing_config.as_ref(),
366 )?;
367 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
368 (
369 layer_sizes_in_bytes,
370 non_mapped_size_in_bytes,
371 layer_sizes_sum + non_mapped_size_in_bytes,
372 )
373 } else if let Some(isq) = in_situ_quant {
374 let weight_pack_factor = isq.pack_factor(dtype);
375 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
376 &config,
377 dtype,
378 weight_pack_factor,
379 matformer_slicing_config.as_ref(),
380 )?;
381 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
382 &config,
383 dtype,
384 weight_pack_factor,
385 matformer_slicing_config.as_ref(),
386 )?;
387 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
388 (
389 layer_sizes_in_bytes,
390 non_mapped_size_in_bytes,
391 layer_sizes_sum + non_mapped_size_in_bytes,
392 )
393 } else {
394 let weight_pack_factor =
396 QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
397 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
398 &config,
399 dtype,
400 weight_pack_factor,
401 matformer_slicing_config.as_ref(),
402 )?;
403 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
404 &config,
405 dtype,
406 weight_pack_factor,
407 matformer_slicing_config.as_ref(),
408 )?;
409 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
410 (
411 layer_sizes_in_bytes,
412 non_mapped_size_in_bytes,
413 layer_sizes_sum + non_mapped_size_in_bytes,
414 )
415 };
416
417 let new = auto_device_map::get_device_layers(
418 &*self.inner,
419 &config,
420 self.inner.num_layers(&config)?,
421 layer_sizes_in_bytes,
422 non_mapped_size_in_bytes,
423 total_model_size_in_bytes,
424 &available_devices,
425 dtype,
426 ¶ms,
427 params.max_seq_len(),
428 paged_attn_config.as_ref(),
429 )?;
430 mapper = DeviceMapSetting::Map(new);
431 }
432
433 let pipeline_mapper = mapper.into_mapper(
434 self.inner.num_layers(&config)?,
435 &device,
436 self.config.topology.as_ref(),
437 )?;
438 let mapper = mapper.into_mapper(
439 self.inner.num_layers(&config)?,
440 &device,
441 self.config.topology.as_ref(),
442 )?;
443 let mut layer_devices = Vec::new();
444 for layer in 0..self.inner.num_layers(&config)? {
445 let device = mapper.device_for(layer, false).cloned();
446 layer_devices.push(device);
447 }
448 let dtype = mapper.get_min_dtype(dtype)?;
449
450 let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
453 if mapping_uses_cpu && paged_attn_config.is_some() {
454 warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
455 paged_attn_config = None;
456 }
457
458 info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
459 if crate::using_flash_attn() {
460 once_log_info("FlashAttention is enabled.");
461 }
462
463 let mut loading_isq = if self.config.imatrix.is_none()
465 && self.config.calibration_file.is_none()
466 && !device.is_cuda()
467 && self.config.write_uqff.is_none()
468 && in_situ_quant.is_some()
469 {
470 let predicates = self.inner.immediate_isq_predicates(&config)?;
471 info!("Applying ISQ to {in_situ_quant:?}");
472 if predicates.is_empty() {
473 warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
474 }
475 mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
476 false
477 } else {
478 in_situ_quant.is_some()
479 };
480
481 if let Some(ref topology) = self.config.topology {
482 loading_isq |= topology
483 .0
484 .iter()
485 .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
486 }
487
488 if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
489 anyhow::bail!(
490 "`imatrix` and `calibration_file` were both specified, this is not allowed."
491 );
492 }
493
494 let load_device = if !loading_isq || self.config.calibration_file.is_some() {
496 loading_isq = false;
497 device.clone()
498 } else {
499 Device::Cpu
500 };
501
502 let attention_mechanism = if paged_attn_config.is_some() {
503 AttentionImplementation::PagedAttention
504 } else {
505 AttentionImplementation::Eager
506 };
507
508 let multi_progress = Arc::new(MultiProgress::new());
509
510 let mut model = if use_nccl {
511 let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
512 dtype,
513 &device,
514 &available_devices,
515 silent,
516 &config,
517 loading_isq,
518 self.config.from_uqff.is_some(),
519 IsqOrganization::Default,
520 &*self.inner,
521 paths.as_ref(),
522 )?;
523
524 match self.kind {
526 ModelKind::Normal => vision_normal_model_loader_sharded!(
527 sharded_vb,
528 config,
529 self.inner,
530 mapper,
531 loading_isq,
532 device.clone(),
533 attention_mechanism,
534 multi_progress.clone(),
535 matformer_slicing_config.clone(),
536 ),
537 _ => unreachable!(),
538 }
539 } else {
540 match self.kind {
541 ModelKind::Normal => vision_normal_model_loader!(
542 paths,
543 Some(dtype),
544 &load_device,
545 layer_devices.clone(),
546 config,
547 self.inner,
548 silent,
549 mapper,
550 loading_isq,
551 self.config.from_uqff.is_some(),
552 device.clone(),
553 attention_mechanism,
554 multi_progress,
555 matformer_slicing_config.clone(),
556 ),
557 _ => unreachable!(),
558 }
559 };
560
561 let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
563 {
564 Some(preprocessor_config) => {
565 serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
566 }
567 None => PreProcessorConfig::default(),
568 };
569 let processor_config: Option<ProcessorConfig> = paths
570 .get_processor_config()
571 .as_ref()
572 .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
573
574 let processor = self.inner.get_processor(
575 &config,
576 processor_config,
577 preprocessor_config.clone(),
578 self.config.max_edge,
579 ); let tokenizer = get_tokenizer(
582 paths.get_tokenizer_filename(),
583 Some(processor.get_special_tokens()),
584 )?;
585
586 let gen_conf: Option<GenerationConfig> = paths
587 .get_gen_conf_filename()
588 .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
589 let chat_template_explicit = paths
590 .get_chat_template_explicit()
591 .as_ref()
592 .map(|x| x.to_string_lossy().to_string());
593 let chat_template = get_chat_template(
594 paths,
595 self.jinja_explicit.as_ref(),
596 chat_template_explicit.as_ref(),
597 self.chat_template.as_ref(),
598 None,
599 );
600
601 if let Some(calibration_file) = &self.config.calibration_file {
602 let calibration_data = std::fs::read_to_string(calibration_file)?;
603 let tokens = tokenizer
605 .encode_fast(calibration_data, false)
606 .map_err(anyhow::Error::msg)?
607 .get_ids()
608 .to_vec();
609 info!(
610 "Collecting imatrix from calibration file `{}` of {} tokens.",
611 calibration_file.display(),
612 tokens.len()
613 );
614 let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
615 let bos_tok_id = tokenizer
616 .token_to_id(&bos_toks[0])
617 .expect("Somehow the bos token is not present.");
618
619 model.begin_track_stats()?;
622
623 const CHUNK_SIZE: usize = 1024;
624 let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
625 let start = Instant::now();
626 for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
627 let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
628 let chunk_len = chunk.len();
629
630 let start = Instant::now();
631 let inputs = make_prompt_chunk(
632 0,
633 vec![&chunk],
634 &[0],
635 &load_device,
636 None,
637 false,
638 None,
639 None,
640 )?;
641 let _ = model.forward(
642 &inputs.input,
643 None, &inputs.positions,
645 inputs.context_lens,
646 inputs.position_ids,
647 model.default_model_specific_args(&inputs.input),
648 None,
649 &inputs.flash_meta,
650 )?;
651 match model.cache_mut() {
652 EitherCache::Full(full) => {
653 for layer in &mut *full.lock() {
654 *layer = None
655 }
656 }
657 EitherCache::Normal(normal) => {
658 for layer in &mut *normal.lock().unwrap().0 {
659 layer.reset();
660 }
661 }
662 }
663 let end = Instant::now();
664 info!(
665 "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
666 i + 1,
667 end.duration_since(start).as_secs_f32()
668 );
669 }
670 load_device.synchronize()?;
671 let end = Instant::now();
672 info!(
673 "Finished collecting imatrix in {:.2}s",
674 end.duration_since(start).as_secs_f32()
675 );
676 }
677
678 if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
680 let imatrix_source = match (
681 self.config.imatrix.as_ref(),
682 self.config.calibration_file.is_some(),
683 ) {
684 (None, false) => None,
685 (Some(file), false) => Some(ImatrixDataSource::File(file)),
686 (None, true) => Some(ImatrixDataSource::Collected),
687 (Some(_), true) => unreachable!(),
688 };
689 model.quantize(
690 in_situ_quant,
691 device.clone(),
692 self.config.topology.as_ref(),
693 silent,
694 imatrix_source,
695 IsqOrganization::Default,
696 self.config.write_uqff.as_ref(),
697 UqffFullSer {
698 tokenizer: &tokenizer,
699 template_filename: paths.get_template_filename(),
700 generation_config: paths.get_gen_conf_filename(),
701 config: config.clone(),
702 processor_filename: paths.get_processor_config(),
703 preprocessor_filename: paths.get_preprocessor_config(),
704 },
705 Arc::new(MultiProgress::new()),
706 )?;
707 } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
708 model.load_from_artifacts(
709 device.clone(),
710 self.config.topology.as_ref(),
711 silent,
712 from_uqff,
713 )?;
714 }
715
716 let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
717 anyhow::ensure!(
718 !matches!(self.kind, ModelKind::Adapter { .. }),
719 "PagedAttention does not support adapter models."
720 );
721 let cache_config = calculate_cache_config(
722 paged_attn_config.mem_gpu,
723 paged_attn_config.mem_cpu,
724 paged_attn_config.block_size,
725 dtype,
726 paged_attn_config.cache_type,
727 model.config(),
728 &device,
729 &layer_devices,
730 silent,
731 )?;
732 let cache_engine =
733 CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
734 (Some(cache_config), Some(cache_engine))
735 } else {
736 (None, None)
737 };
738
739 let max_seq_len = model.max_seq_len();
740 let llg_factory = build_llg_factory(tokenizer.clone())?;
741 let num_hidden_layers = match model.cache() {
742 EitherCache::Full(full) => full.lock().len(),
743 EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
744 };
745 let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
746 let sliding_window = model.config().sliding_window;
747 let model_metadata = Arc::new(model.config().clone());
748 Ok(Arc::new(Mutex::new(VisionPipeline {
749 model,
750 tokenizer: tokenizer.into(),
751 chat_template: Arc::new(chat_template),
752 model_id: self.model_id.clone(),
753 metadata: Arc::new(GeneralMetadata {
754 max_seq_len,
755 llg_factory: Some(llg_factory),
756 is_xlora: false,
757 num_hidden_layers,
758 eos_tok: eos,
759 kind: self.kind.clone(),
760 no_kv_cache: false,
761 no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
762 activation_dtype: dtype,
763 sliding_window,
764 cache_config,
765 cache_engine,
766 prompt_chunksize: self.config.prompt_chunksize,
767 model_metadata: Some(model_metadata),
768 modalities: self.inner.modalities(&config)?,
769 }),
770 processor,
771 prefixer: self.inner.prefixer(&config),
772 preprocessor_config: Arc::new(preprocessor_config),
773 topology: self.config.topology.clone(),
774 silent,
775 template_filename: paths.get_template_filename().clone(),
776 generation_config: paths.get_gen_conf_filename().cloned(),
777 config,
778 processor_filename: paths.get_processor_config().clone(),
779 preprocessor_filename: paths.get_preprocessor_config().clone(),
780 mapper: pipeline_mapper,
781 imatrix: self.config.imatrix.clone(),
782 })))
783 }
784
785 fn get_id(&self) -> String {
786 self.model_id.to_string()
787 }
788
789 fn get_kind(&self) -> ModelKind {
790 self.kind.clone()
791 }
792}
793
794impl PreProcessingMixin for VisionPipeline {
795 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
796 Some(self.chat_template.clone())
797 }
798 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
799 Some(self.preprocessor_config.clone())
800 }
801 fn get_processor(&self) -> Arc<dyn super::Processor> {
802 self.processor.clone()
803 }
804}
805
806impl IsqPipelineMixin for VisionPipeline {
807 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
808 let device = self.device().clone();
809 self.model
810 .quantize(
811 Some(dtype),
812 device,
813 self.topology.as_ref(),
814 self.silent,
815 self.imatrix.as_ref().map(ImatrixDataSource::File),
816 IsqOrganization::Default,
817 None,
818 UqffFullSer {
819 tokenizer: &self.tokenizer,
820 template_filename: &self.template_filename,
821 generation_config: self.generation_config.as_ref(),
822 config: self.config.clone(),
823 processor_filename: &self.processor_filename,
824 preprocessor_filename: &self.preprocessor_filename,
825 },
826 Arc::new(MultiProgress::new()),
827 )
828 .map_err(anyhow::Error::msg)
829 }
830}
831
832impl CacheManagerMixin for VisionPipeline {
833 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
834 if matches!(self.model.cache(), EitherCache::Full(_)) {
835 FullCacheManager.clone_in_cache(self, seqs, false)
836 } else {
837 NormalCacheManager.clone_in_cache(self, seqs, false)
838 }
839 }
840 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
841 if matches!(self.model.cache(), EitherCache::Full(_)) {
842 FullCacheManager.clone_out_cache(self, seqs, false)
843 } else {
844 NormalCacheManager.clone_out_cache(self, seqs, false)
845 }
846 }
847 fn set_none_cache(
848 &self,
849 seqs: &mut [&mut Sequence],
850 reset_non_granular: bool,
851 modify_draft_cache: bool,
852
853 load_preallocated_cache: bool,
854 ) {
855 if matches!(self.model.cache(), EitherCache::Full(_)) {
856 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
857 } else {
858 NormalCacheManager.set_none_cache(
859 self,
860 seqs,
861 modify_draft_cache,
862 load_preallocated_cache,
863 );
864 }
865 if reset_non_granular {
866 self.reset_non_granular_state()
867 }
868 }
869 fn cache(&self) -> &EitherCache {
870 self.model.cache()
871 }
872}
873
874impl MetadataMixin for VisionPipeline {
875 fn device(&self) -> Device {
876 self.model.device().clone()
877 }
878 fn get_metadata(&self) -> Arc<GeneralMetadata> {
879 self.metadata.clone()
880 }
881 fn name(&self) -> String {
882 self.model_id.clone()
883 }
884 fn reset_non_granular_state(&self) {}
885 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
886 Some(self.tokenizer.clone())
887 }
888 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
889 Some(&*self.mapper)
890 }
891}
892
893#[async_trait::async_trait]
894impl Pipeline for VisionPipeline {
895 fn forward_inputs(
896 &mut self,
897 inputs: Box<dyn Any>,
898 return_raw_logits: bool,
899 ) -> candle_core::Result<ForwardInputsResult> {
900 let ModelInputs {
901 input_ids,
902 seqlen_offsets,
903 context_lens,
904 position_ids,
905 pixel_values,
906 model_specific_args,
907 paged_attn_meta,
908 flash_meta,
909 } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
910 let metadata = self.get_metadata();
911 let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
912 (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
913 (Some(_), None) => {
914 candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
916 }
917 (None, Some(_)) => {
918 candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
920 }
921 (None, None) => None,
922 };
923 let logits = self.model.forward(
924 &input_ids,
925 pixel_values,
926 &seqlen_offsets,
927 context_lens,
928 position_ids,
929 model_specific_args,
930 paged_attn_meta,
931 &flash_meta,
932 )?;
933 if return_raw_logits {
934 Ok(ForwardInputsResult::RawLogits { logits })
935 } else {
936 Ok(ForwardInputsResult::CausalGeneration { logits })
937 }
938 }
939 async fn sample_causal_gen(
940 &self,
941 seqs: &mut [&mut Sequence],
942 logits: Vec<Tensor>,
943 prefix_cacher: &mut PrefixCacheManagerV2,
944 disable_eos_stop: bool,
945 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
946 ) -> Result<(), candle_core::Error> {
947 sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
948 }
949 fn category(&self) -> ModelCategory {
950 ModelCategory::Vision {
951 prefixer: self.prefixer.clone(),
952 }
953 }
954}
955
956impl AnyMoePipelineMixin for VisionPipeline {
957 fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
958 self.model.finish_training(gate_model_id)
959 }
960 fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
961 self.model.get_vars()
962 }
963 fn amoe_base_model_trainable_params(&self) -> usize {
964 self.model.trainable_params()
965 }
966 fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
967 self.model.take_cached_gating_outputs()
968 }
969 fn amoe_create_layers(
970 &mut self,
971 model_ids: Vec<String>,
972 token: &TokenSource,
973 revision: Option<String>,
974 match_regex: &str,
975 config: crate::amoe::AnyMoeConfig,
976 dtype: candle_core::DType,
977 dev: &Device,
978 (prefix, mlp): (String, String),
979 layers: Vec<usize>,
980 expert_type: AnyMoeExpertType,
981 silent: bool,
982 gate_model_id: Option<String>,
983 ) -> candle_core::Result<()> {
984 let mut vbs = Vec::new();
985 let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
987 for model_id in model_ids {
988 let model_id_str = &model_id;
989 let model_id = Path::new(&model_id);
990
991 let api = {
992 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
993 let mut api = ApiBuilder::from_cache(cache)
994 .with_progress(!silent)
995 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
996 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
997 api = api.with_cache_dir(x.into());
998 }
999 api.build().map_err(candle_core::Error::msg)?
1000 };
1001 let revision = revision.clone().unwrap_or("main".to_string());
1002 let api = api.repo(Repo::with_revision(
1003 model_id_str.clone(),
1004 RepoType::Model,
1005 revision.clone(),
1006 ));
1007
1008 let mut filenames = vec![];
1009 for rfilename in
1010 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1011 {
1012 filenames.push(api_get_file!(api, &rfilename, model_id));
1013 }
1014
1015 let regex = regex.clone();
1016 let match_regex_clone = match_regex.to_string();
1017 let layers_clone = layers.clone();
1018 let vb = from_mmaped_safetensors(
1019 filenames,
1020 vec![],
1021 Some(dtype),
1022 dev,
1023 vec![None],
1024 silent,
1025 None,
1026 move |key| {
1027 if regex.is_match(&key) {
1028 let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1031 let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1032 let layer_n = key[first_layer_idx + 1..last_layer_idx]
1033 .parse::<usize>()
1034 .unwrap();
1035 layers_clone.contains(&layer_n) || layers_clone.is_empty()
1036 } else {
1037 false
1038 }
1039 },
1040 Arc::new(|_| DeviceForLoadTensor::Base),
1041 )?;
1042 vbs.push(vb);
1043 }
1044
1045 let gate_vb = if let Some(gate_model_id) = gate_model_id {
1046 let model_id_str = &gate_model_id;
1047 let model_id = Path::new(&gate_model_id);
1048
1049 let api = {
1050 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1051 let mut api = ApiBuilder::from_cache(cache)
1052 .with_progress(!silent)
1053 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1054 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1055 api = api.with_cache_dir(x.into());
1056 }
1057 api.build().map_err(candle_core::Error::msg)?
1058 };
1059 let revision = revision.clone().unwrap_or("main".to_string());
1060 let api = api.repo(Repo::with_revision(
1061 model_id_str.clone(),
1062 RepoType::Model,
1063 revision.clone(),
1064 ));
1065
1066 let mut gate_filenames = vec![];
1067 for rfilename in
1068 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1069 {
1070 gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1071 }
1072 assert_eq!(
1073 gate_filenames.len(),
1074 1,
1075 "Gate model ID must contain only one .safetensors file"
1076 );
1077
1078 let vb = from_mmaped_safetensors(
1079 gate_filenames.clone(),
1080 vec![],
1081 Some(dtype),
1082 dev,
1083 vec![None],
1084 silent,
1085 None,
1086 |_| true,
1087 Arc::new(|_| DeviceForLoadTensor::Base),
1088 )?;
1089 info!(
1090 "Loaded gating layers from `{}`",
1091 gate_filenames[0].display()
1092 );
1093 Some(vb)
1094 } else {
1095 None
1096 };
1097
1098 self.model
1099 .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
1100 }
1101 fn amoe_supported(&self) -> bool {
1102 self.model.amoe_supported()
1103 }
1104}