use super::isq::ImatrixDataSource;
use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoVisionLoader,
    CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
    GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
    ModelKind, ModelPaths, MultimodalPromptPrefixer, Phi4MMLoader, PreProcessingMixin, Processor,
    Qwen2VLLoader, TokenSource, VLlama4Loader, VLlamaLoader, VisionModel, VisionModelLoader,
};
use super::{
    Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader,
    Phi3VLoader, Qwen2_5VLLoader, VisionLoaderType,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::llg::build_llg_factory;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::ModelInputs;
use crate::{
    api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
    vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

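/// A fully loaded vision (multimodal) model pipeline: the model itself plus its
/// tokenizer, chat template, input processors, device mapper, and the paths
/// retained for full UQFF serialization.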
pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}

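/// A [`Loader`] for vision models, wrapping a concrete [`VisionModelLoader`]
/// together with the model ID, loading configuration, and state resolved at
/// load time (token source, revision, UQFF paths).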
pub struct VisionLoader {
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

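/// Builder for a vision model [`Loader`].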
#[derive(Default)]
pub struct VisionLoaderBuilder {
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

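/// Loading options specific to vision models: device topology, UQFF read/write
/// paths, imatrix or calibration data for ISQ, and Matformer slicing.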
#[derive(Clone, Default)]
pub struct VisionSpecificConfig {
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub max_edge: Option<u32>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
    pub matformer_config_path: Option<PathBuf>,
    pub matformer_slice_name: Option<String>,
}

impl VisionLoaderBuilder {
    pub fn new(
        config: VisionSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

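    /// Build the [`Loader`], dispatching on the requested [`VisionLoaderType`]
    /// or falling back to [`AutoVisionLoader`] when none is given.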
    pub fn build(self, loader: Option<VisionLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn VisionModelLoader> = match loader {
            Some(VisionLoaderType::Phi3V) => Box::new(Phi3VLoader),
            Some(VisionLoaderType::Idefics2) => Box::new(Idefics2Loader),
            Some(VisionLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
            Some(VisionLoaderType::LLaVA) => Box::new(LLaVALoader),
            Some(VisionLoaderType::VLlama) => Box::new(VLlamaLoader),
            Some(VisionLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
            Some(VisionLoaderType::Idefics3) => Box::new(Idefics3Loader),
            Some(VisionLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
            Some(VisionLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
            Some(VisionLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
            Some(VisionLoaderType::Gemma3) => Box::new(Gemma3Loader),
            Some(VisionLoaderType::Mistral3) => Box::new(Mistral3Loader),
            Some(VisionLoaderType::Llama4) => Box::new(VLlama4Loader),
            Some(VisionLoaderType::Gemma3n) => Box::new(Gemma3nLoader),
            None => Box::new(AutoVisionLoader),
        };
        Box::new(VisionLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            xlora_model_id: None,
            xlora_order: None,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for VisionLoader {
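    /// Resolve model files from the Hugging Face Hub (honoring the configured
    /// cache path), then delegate to `load_model_from_path`.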
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention(&config) {
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

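        // Pick the candidate devices: a daemon worker owns a single CUDA device
        // derived from its worker rank, NCCL pins device 0, and otherwise all
        // devices similar to the requested one are candidates for device mapping.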
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

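        // Optionally load a Matformer config; the named slice selects a
        // sub-network of the full model and feeds into the size estimates below.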
        let matformer_slicing_config = if let Some(matformer_path) =
            &self.config.matformer_config_path
        {
            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
            info!("Loading Matformer config from {:?}", matformer_path);
            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);

            if let Some(slice_name) = &self.config.matformer_slice_name {
                info!("Using Matformer slice: {}", slice_name);
                Some(MatformerSliceConfig::new(slice_name.clone(), config))
            } else {
                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
                None
            }
        } else {
            None
        };

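        // Resolve the device map: NCCL gets a dummy single-device map, while the
        // automatic setting computes a per-layer map from memory estimates.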
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
            params = params.maybe_promote_to_vision();

            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

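            // Estimate per-layer and non-mapped sizes for the automatic device map.
            // The weight pack factor comes from, in priority order: UQFF artifacts
            // (averaged over all serialized tensors), the requested ISQ type, or the
            // model's own quantization config.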
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

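        // Two identical mappers: one is kept by the pipeline for runtime device
        // queries, the other is consumed while loading the model.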
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu && paged_attn_config.is_some() {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

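        // Prefer immediate ISQ (quantize each tensor as it is loaded) when there
        // is no imatrix, calibration pass, or UQFF write pending and the target
        // device is not CUDA; otherwise quantization happens after loading.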
        let mut loading_isq = if self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && self.config.write_uqff.is_none()
            && in_situ_quant.is_some()
        {
            let predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
            false
        } else {
            in_situ_quant.is_some()
        };

        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

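        // With post-load ISQ (and no calibration pass), weights are first loaded
        // to CPU so that only the quantized copies end up on the target device.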
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

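        // Load the model: sharded across devices via NCCL, or on the (possibly
        // device-mapped) local devices.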
        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        };

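        // Load the processor/preprocessor configs and build the input processor
        // and tokenizer (including any processor-specific special tokens).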
        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        );

        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

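        // Calibration pass: run the calibration text through the model in
        // BOS-prefixed, text-only chunks to collect imatrix statistics for ISQ.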
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            model.begin_track_stats()?;

            const CHUNK_SIZE: usize = 1024;
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    None,
                )?;
                let _ = model.forward(
                    &inputs.input,
                    None,
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
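                // Clear the KV cache between chunks; only the activation
                // statistics matter for the imatrix.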
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

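        // Apply ISQ now (optionally writing UQFF artifacts alongside), or load
        // pre-quantized weights from existing UQFF artifacts.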
        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                (Some(_), true) => unreachable!(),
            };
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                },
                Arc::new(MultiProgress::new()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

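        // With PagedAttention enabled, size and allocate the KV cache engine.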
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                model_metadata: Some(model_metadata),
                modalities: self.inner.modalities(&config)?,
            }),
            processor,
            prefixer: self.inner.prefixer(&config),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for VisionPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        Some(self.preprocessor_config.clone())
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        self.processor.clone()
    }
}

impl IsqPipelineMixin for VisionPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                },
                Arc::new(MultiProgress::new()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for VisionPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for VisionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Vision {
            prefixer: self.prefixer.clone(),
        }
    }
}

impl AnyMoePipelineMixin for VisionPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
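            // Accept only tensors whose name matches the regex and whose layer
            // index, parsed from the name just before the matched suffix, is one
            // of the requested layers (an empty list means all layers).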
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}