use super::isq::ImatrixDataSource;
use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoVisionLoader,
    CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
    GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
    ModelKind, ModelPaths, MultimodalPromptPrefixer, Phi4MMLoader, PreProcessingMixin, Processor,
    Qwen2VLLoader, Qwen3VLLoader, TokenSource, VLlama4Loader, VLlamaLoader, VisionModel,
    VisionModelLoader,
};
use super::{
    Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader,
    Phi3VLoader, Qwen2_5VLLoader, VisionLoaderType,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::llg::build_llg_factory;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::ModelInputs;
use crate::{
    api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
    vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

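/// A loaded vision (multimodal) model pipeline: the model plus its tokenizer, chat
/// template, processors, device mapper, and the metadata needed to run inference.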
pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}

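/// [`Loader`] implementation for vision models: resolves the model files (locally or from
/// the Hugging Face Hub) and constructs a [`VisionPipeline`].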
pub struct VisionLoader {
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

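/// Builder for a [`VisionLoader`]: configure the model ID, chat template, tokenizer, and
/// LoRA adapters, then call [`VisionLoaderBuilder::build`] to obtain a [`Loader`].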
#[derive(Default)]
pub struct VisionLoaderBuilder {
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

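/// Vision-pipeline-specific configuration: device topology, UQFF read/write paths, image
/// edge limit, imatrix/calibration data for ISQ, and Matformer slicing options.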
#[derive(Clone, Default)]
pub struct VisionSpecificConfig {
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub max_edge: Option<u32>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
    pub matformer_config_path: Option<PathBuf>,
    pub matformer_slice_name: Option<String>,
}

impl VisionLoaderBuilder {
    pub fn new(
        config: VisionSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

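    /// Select the concrete [`VisionModelLoader`] for `loader` (falling back to
    /// [`AutoVisionLoader`] when `None`) and wrap it in a [`VisionLoader`].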
    pub fn build(self, loader: Option<VisionLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn VisionModelLoader> = match loader {
            Some(VisionLoaderType::Phi3V) => Box::new(Phi3VLoader),
            Some(VisionLoaderType::Idefics2) => Box::new(Idefics2Loader),
            Some(VisionLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
            Some(VisionLoaderType::LLaVA) => Box::new(LLaVALoader),
            Some(VisionLoaderType::VLlama) => Box::new(VLlamaLoader),
            Some(VisionLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
            Some(VisionLoaderType::Idefics3) => Box::new(Idefics3Loader),
            Some(VisionLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
            Some(VisionLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
            Some(VisionLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
            Some(VisionLoaderType::Gemma3) => Box::new(Gemma3Loader),
            Some(VisionLoaderType::Mistral3) => Box::new(Mistral3Loader),
            Some(VisionLoaderType::Llama4) => Box::new(VLlama4Loader),
            Some(VisionLoaderType::Gemma3n) => Box::new(Gemma3nLoader),
            Some(VisionLoaderType::Qwen3VL) => Box::new(Qwen3VLLoader),
            None => Box::new(AutoVisionLoader),
        };
        Box::new(VisionLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            xlora_model_id: None,
            xlora_order: None,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for VisionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention(&config) {
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

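        // Determine the candidate devices: a distributed daemon worker gets its own CUDA
        // device and stream, the NCCL head process starts from device 0, and otherwise every
        // device similar to the requested one is considered for device mapping.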
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

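        // Optionally load a Matformer config; a named slice selects a sub-network of the
        // full model, otherwise the model's default slice is used.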
        let matformer_slicing_config = if let Some(matformer_path) =
            &self.config.matformer_config_path
        {
            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
            info!("Loading Matformer config from {:?}", matformer_path);
            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);

            if let Some(slice_name) = &self.config.matformer_slice_name {
                info!("Using Matformer slice: {}", slice_name);
                Some(MatformerSliceConfig::new(slice_name.clone(), config))
            } else {
                warn!(
                    "Matformer config loaded but no slice name specified. Models will use their default slice."
                );
                None
            }
        } else {
            None
        };

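        // Resolve the device map. With NCCL the mapping is a per-rank dummy map; with
        // automatic mapping, per-layer and non-mapped weight sizes are estimated so layers
        // can be distributed across the available devices.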
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
            params = params.maybe_promote_to_vision();

            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

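            // Estimate per-layer and non-mapped weight sizes. The weight pack factor comes
            // from the UQFF artifacts if provided, from the requested ISQ type, or from the
            // model's own quantization config.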
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu && paged_attn_config.is_some() {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

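        // Decide how in-situ quantization (ISQ) is applied: either immediately while the
        // weights are loaded (via immediate-ISQ predicates) or as a separate quantization
        // pass after loading.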
        let mut loading_isq = if self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && self.config.write_uqff.is_none()
            && in_situ_quant.is_some()
        {
            let predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
            false
        } else {
            in_situ_quant.is_some()
        };

        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

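        // Load the model weights: sharded across ranks when NCCL is in use, otherwise via
        // the normal (possibly device-mapped) vision loading path.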
        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        };

        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        );

        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

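        // If a calibration file was provided, run it through the model in chunks to collect
        // imatrix statistics, clearing the KV cache after each chunk.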
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            model.begin_track_stats()?;

            const CHUNK_SIZE: usize = 1024;
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    None,
                )?;
                let _ = model.forward(
                    &inputs.input,
                    None,
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

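        // Apply ISQ now (optionally serializing UQFF artifacts), or reload pre-quantized
        // weights from existing UQFF artifacts.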
        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                (Some(_), true) => unreachable!(),
            };
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                },
                Arc::new(MultiProgress::new()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

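        // When PagedAttention is enabled, size the KV-cache blocks from the configured GPU
        // memory budget and build the cache engine.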
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                model_metadata: Some(model_metadata),
                modalities: self.inner.modalities(&config)?,
            }),
            processor,
            prefixer: self.inner.prefixer(&config),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for VisionPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        Some(self.preprocessor_config.clone())
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        self.processor.clone()
    }
}

impl IsqPipelineMixin for VisionPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                },
                Arc::new(MultiProgress::new()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for VisionPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for VisionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Vision {
            prefixer: self.prefixer.clone(),
        }
    }
}

impl AnyMoePipelineMixin for VisionPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}