1use super::isq::ImatrixDataSource;
2use super::isq::UqffFullSer;
3use super::{
4 get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoVisionLoader,
5 CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
6 GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
7 ModelKind, ModelPaths, MultimodalPromptPrefixer, Phi4MMLoader, PreProcessingMixin, Processor,
8 Qwen2VLLoader, Qwen3VLLoader, Qwen3VLMoELoader, TokenSource, VLlama4Loader, VLlamaLoader,
9 VisionModel, VisionModelLoader,
10};
11use super::{
12 Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader,
13 Phi3VLoader, Qwen2_5VLLoader, VisionLoaderType,
14};
15use crate::attention::ATTENTION_CHUNK_SIZE;
16use crate::device_map::{self, DeviceMapper};
17use crate::distributed::{self, WorkerTransferData};
18use crate::kv_cache::{FullCacheManager, NormalCacheManager};
19use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
20use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
21use crate::pipeline::llg::build_llg_factory;
22use crate::pipeline::loaders::auto_device_map;
23use crate::pipeline::loaders::QuantizationConfigShim;
24use crate::pipeline::sampling::sample_and_add_toks;
25use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
26use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
27use crate::prefix_cacher::PrefixCacheManagerV2;
28use crate::sequence::Sequence;
29use crate::utils::tokenizer::get_tokenizer;
30use crate::utils::varbuilder_utils::DeviceForLoadTensor;
31use crate::utils::{
32 progress::{new_multi_progress, ProgressScopeGuard},
33 tokens::get_token,
34 varbuilder_utils::from_mmaped_safetensors,
35};
36use crate::vision_models::preprocessor_config::PreProcessorConfig;
37use crate::vision_models::processor_config::ProcessorConfig;
38use crate::vision_models::ModelInputs;
39use crate::{
40 api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
41 vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
42 PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
43};
44use anyhow::Result;
45use candle_core::{Device, Tensor, Var};
46use hf_hub::Cache;
47use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
48use mistralrs_quant::log::once_log_info;
49use mistralrs_quant::{
50 AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
51};
52use rand_isaac::Isaac64Rng;
53use regex_automata::meta::Regex;
54use std::any::Any;
55use std::borrow::Cow;
56use std::path::{Path, PathBuf};
57use std::str::FromStr;
58use std::sync::{Arc, RwLock};
59use std::time::Instant;
60use std::{env, fs};
61use tokenizers::Tokenizer;
62use tokio::sync::Mutex;
63use tracing::{info, warn};
64
/// A loaded multimodal (vision) model plus everything needed to serve it: tokenizer,
/// chat template, image processor/preprocessor configuration and device mapping.
///
/// Constructed by [`VisionLoader`] when a model is loaded.
pub struct VisionPipeline {
    /// The vision model; forward passes, KV caches and (re-)quantization delegate to it.
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    /// Identifier reported by `MetadataMixin::name`.
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    /// Model-specific multimodal input processor.
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    /// Topology reused when the model is re-quantized.
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    /// Device mapper exposed through `MetadataMixin::device_mapper`.
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    // The fields below are retained so a later re-quantization can rebuild the
    // `UqffFullSer` description of the model.
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}
86
/// Loader for vision models; created by [`VisionLoaderBuilder::build`].
///
/// The `RwLock` fields are filled in while loading from the Hugging Face Hub so that
/// state resolved during `load_model_from_hf` is recorded on the loader.
pub struct VisionLoader {
    /// Architecture-specific loader (or the auto-detecting loader) chosen by the builder.
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    // Always `None` when built through `VisionLoaderBuilder::build`.
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    /// Token source recorded by `load_model_from_hf`.
    token_source: RwLock<Option<TokenSource>>,
    /// Revision recorded by `load_model_from_hf`.
    revision: RwLock<Option<String>>,
    /// Resolved UQFF artifact paths, set by `load_model_from_hf` when `config.from_uqff` is set.
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    // Set by `VisionLoaderBuilder::with_lora`; NOTE(review): presumably consumed by the
    // path-resolution macros — confirm.
    lora_adapter_ids: Option<Vec<String>>,
}
104
105#[derive(Default)]
106pub struct VisionLoaderBuilder {
108 model_id: Option<String>,
109 config: VisionSpecificConfig,
110 kind: ModelKind,
111 chat_template: Option<String>,
112 tokenizer_json: Option<String>,
113 jinja_explicit: Option<String>,
114 hf_cache_path: Option<PathBuf>,
115 lora_adapter_ids: Option<Vec<String>>,
116}
117
118#[derive(Clone, Default)]
119pub struct VisionSpecificConfig {
121 pub topology: Option<Topology>,
122 pub write_uqff: Option<PathBuf>,
123 pub from_uqff: Option<Vec<PathBuf>>,
124 pub max_edge: Option<u32>,
125 pub imatrix: Option<PathBuf>,
126 pub calibration_file: Option<PathBuf>,
127 pub hf_cache_path: Option<PathBuf>,
128 pub matformer_config_path: Option<PathBuf>,
129 pub matformer_slice_name: Option<String>,
130}
131
132impl VisionLoaderBuilder {
133 pub fn new(
134 config: VisionSpecificConfig,
135 chat_template: Option<String>,
136 tokenizer_json: Option<String>,
137 model_id: Option<String>,
138 jinja_explicit: Option<String>,
139 ) -> Self {
140 Self {
141 config,
142 chat_template,
143 tokenizer_json,
144 model_id,
145 jinja_explicit,
146 kind: ModelKind::Normal,
147 hf_cache_path: None,
148 ..Default::default()
149 }
150 }
151
152 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
153 self.hf_cache_path = Some(hf_cache_path);
154 self
155 }
156
157 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
158 self.kind = ModelKind::Adapter {
159 adapter: AdapterKind::Lora,
160 };
161 self.lora_adapter_ids = Some(lora_adapter_ids);
162 self
163 }
164
165 pub fn build(self, loader: Option<VisionLoaderType>) -> Box<dyn Loader> {
166 let loader: Box<dyn VisionModelLoader> = match loader {
167 Some(VisionLoaderType::Phi3V) => Box::new(Phi3VLoader),
168 Some(VisionLoaderType::Idefics2) => Box::new(Idefics2Loader),
169 Some(VisionLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
170 Some(VisionLoaderType::LLaVA) => Box::new(LLaVALoader),
171 Some(VisionLoaderType::VLlama) => Box::new(VLlamaLoader),
172 Some(VisionLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
173 Some(VisionLoaderType::Idefics3) => Box::new(Idefics3Loader),
174 Some(VisionLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
175 Some(VisionLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
176 Some(VisionLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
177 Some(VisionLoaderType::Gemma3) => Box::new(Gemma3Loader),
178 Some(VisionLoaderType::Mistral3) => Box::new(Mistral3Loader),
179 Some(VisionLoaderType::Llama4) => Box::new(VLlama4Loader),
180 Some(VisionLoaderType::Gemma3n) => Box::new(Gemma3nLoader),
181 Some(VisionLoaderType::Qwen3VL) => Box::new(Qwen3VLLoader),
182 Some(VisionLoaderType::Qwen3VLMoE) => Box::new(Qwen3VLMoELoader),
183 None => Box::new(AutoVisionLoader),
184 };
185 Box::new(VisionLoader {
186 inner: loader,
187 model_id: self.model_id.unwrap(),
188 config: self.config,
189 kind: self.kind,
190 chat_template: self.chat_template,
191 tokenizer_json: self.tokenizer_json,
192 xlora_model_id: None,
193 xlora_order: None,
194 jinja_explicit: self.jinja_explicit,
195 token_source: RwLock::new(None),
196 revision: RwLock::new(None),
197 from_uqff: RwLock::new(None),
198 hf_cache_path: self.hf_cache_path,
199 lora_adapter_ids: self.lora_adapter_ids,
200 })
201 }
202}
203
204impl Loader for VisionLoader {
205 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
206 fn load_model_from_hf(
207 &self,
208 revision: Option<String>,
209 token_source: TokenSource,
210 dtype: &dyn TryIntoDType,
211 device: &Device,
212 silent: bool,
213 mapper: DeviceMapSetting,
214 in_situ_quant: Option<IsqType>,
215 paged_attn_config: Option<PagedAttentionConfig>,
216 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
217 let _progress_guard = ProgressScopeGuard::new(silent);
218 let cache = self
219 .hf_cache_path
220 .clone()
221 .map(Cache::new)
222 .unwrap_or_default();
223 GLOBAL_HF_CACHE.get_or_init(|| cache);
224
225 let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
226 LocalModelPaths,
227 &token_source,
228 revision.clone(),
229 self,
230 None,
231 None,
232 silent,
233 self.config.from_uqff.is_some()
234 );
235 if let Some(from_uqff) = self.config.from_uqff.clone() {
236 *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
237 }
238 *self
239 .token_source
240 .write()
241 .expect("Failed to write to token source") = Some(token_source);
242 *self.revision.write().expect("Failed to write to revision") = revision;
243 self.load_model_from_path(
244 &paths?,
245 dtype,
246 device,
247 silent,
248 mapper,
249 in_situ_quant,
250 paged_attn_config,
251 )
252 }
253
254 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
255 fn load_model_from_path(
256 &self,
257 paths: &Box<dyn ModelPaths>,
258 dtype: &dyn TryIntoDType,
259 device: &Device,
260 silent: bool,
261 mut mapper: DeviceMapSetting,
262 mut in_situ_quant: Option<IsqType>,
263 mut paged_attn_config: Option<PagedAttentionConfig>,
264 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
265 let _progress_guard = ProgressScopeGuard::new(silent);
266 let config = std::fs::read_to_string(paths.get_config_filename())?;
267
268 if !self.inner.supports_paged_attention(&config) {
269 paged_attn_config = None;
270 }
271
272 info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
273
274 let use_nccl = mistralrs_quant::distributed::use_nccl();
275
276 let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
277 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
278 let WorkerTransferData::Init { id: _, worker_rank } = payload;
279 vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
280 } else if use_nccl {
281 vec![candle_core::Device::new_cuda_with_stream(0)?]
282 } else {
283 device_map::get_all_similar_devices(device)?
284 };
285 #[cfg(feature = "cuda")]
286 for device in &available_devices {
287 if let Device::Cuda(dev) = device {
288 unsafe { dev.disable_event_tracking() };
289 }
290 }
291 let device = if use_nccl {
292 available_devices[0].clone()
293 } else {
294 device.clone()
295 };
296
297 let matformer_slicing_config = if let Some(matformer_path) =
299 &self.config.matformer_config_path
300 {
301 use crate::matformer::{MatformerConfig, MatformerSliceConfig};
302 info!("Loading Matformer config from {:?}", matformer_path);
303 let config = Arc::new(MatformerConfig::from_file(matformer_path)?);
304
305 if let Some(slice_name) = &self.config.matformer_slice_name {
306 info!("Using Matformer slice: {}", slice_name);
307 Some(MatformerSliceConfig::new(slice_name.clone(), config))
308 } else {
309 warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
312 None
313 }
314 } else {
315 None
316 };
317
318 if use_nccl {
320 mapper = DeviceMapSetting::DummyNccl {
321 nm_device: available_devices[0].clone(),
322 };
323 } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
324 params = params.maybe_promote_to_vision();
326
327 let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
329
330 if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
332 in_situ_quant = None;
333 }
334
335 let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
338 if let Some(serialized) = &*self.from_uqff.read().unwrap() {
339 let weight_pack_factor = {
340 let ser_artifacts = unsafe {
341 candle_core::safetensors::MmapedSafetensors::multi(serialized)?
342 };
343 let mut total_pack_factors = 0;
344 let total_tensors = ser_artifacts.tensors().len();
345 for (_, artifact) in ser_artifacts.tensors() {
346 let artifact = artifact.data();
347 let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
349 let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
350 {
351 QuantizedSerdeType::Hqq => {
352 HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
353 .pack_factor(dtype)
354 }
355 QuantizedSerdeType::Gguf => {
356 GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
357 .pack_factor(dtype)
358 }
359 QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
360 QuantizedSerdeType::Unquant => 1,
361 QuantizedSerdeType::Afq => {
362 AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
363 .pack_factor(dtype)
364 }
365 };
366 total_pack_factors += pack_factor;
367 }
368
369 total_pack_factors / total_tensors
370 };
371
372 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
373 &config,
374 dtype,
375 weight_pack_factor,
376 matformer_slicing_config.as_ref(),
377 )?;
378 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
379 &config,
380 dtype,
381 weight_pack_factor,
382 matformer_slicing_config.as_ref(),
383 )?;
384 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
385 (
386 layer_sizes_in_bytes,
387 non_mapped_size_in_bytes,
388 layer_sizes_sum + non_mapped_size_in_bytes,
389 )
390 } else if let Some(isq) = in_situ_quant {
391 let weight_pack_factor = isq.pack_factor(dtype);
392 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
393 &config,
394 dtype,
395 weight_pack_factor,
396 matformer_slicing_config.as_ref(),
397 )?;
398 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
399 &config,
400 dtype,
401 weight_pack_factor,
402 matformer_slicing_config.as_ref(),
403 )?;
404 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
405 (
406 layer_sizes_in_bytes,
407 non_mapped_size_in_bytes,
408 layer_sizes_sum + non_mapped_size_in_bytes,
409 )
410 } else {
411 let weight_pack_factor =
413 QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
414 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
415 &config,
416 dtype,
417 weight_pack_factor,
418 matformer_slicing_config.as_ref(),
419 )?;
420 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
421 &config,
422 dtype,
423 weight_pack_factor,
424 matformer_slicing_config.as_ref(),
425 )?;
426 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
427 (
428 layer_sizes_in_bytes,
429 non_mapped_size_in_bytes,
430 layer_sizes_sum + non_mapped_size_in_bytes,
431 )
432 };
433
434 let new = auto_device_map::get_device_layers(
435 &*self.inner,
436 &config,
437 self.inner.num_layers(&config)?,
438 layer_sizes_in_bytes,
439 non_mapped_size_in_bytes,
440 total_model_size_in_bytes,
441 &available_devices,
442 dtype,
443 ¶ms,
444 paged_attn_config.as_ref(),
445 )?;
446 mapper = DeviceMapSetting::Map(new);
447 }
448
449 let pipeline_mapper = mapper.into_mapper(
450 self.inner.num_layers(&config)?,
451 &device,
452 self.config.topology.as_ref(),
453 )?;
454 let mapper = mapper.into_mapper(
455 self.inner.num_layers(&config)?,
456 &device,
457 self.config.topology.as_ref(),
458 )?;
459 let mut layer_devices = Vec::new();
460 for layer in 0..self.inner.num_layers(&config)? {
461 let device = mapper.device_for(layer, false).cloned();
462 layer_devices.push(device);
463 }
464 let dtype = mapper.get_min_dtype(dtype)?;
465
466 let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
469 if mapping_uses_cpu && paged_attn_config.is_some() {
470 warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
471 paged_attn_config = None;
472 }
473
474 info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
475 if crate::using_flash_attn() {
476 once_log_info("FlashAttention is enabled.");
477 }
478
479 let topology_overrides = self
480 .config
481 .topology
482 .as_ref()
483 .map(|topology| {
484 topology
485 .pattern_overrides()
486 .into_iter()
487 .map(|(regex, layer)| ImmediateIsqOverride {
488 predicate: regex,
489 ty: layer.isq,
490 device: layer.device.clone(),
491 })
492 .collect::<Vec<_>>()
493 })
494 .unwrap_or_default();
495 let has_override_isq = topology_overrides
496 .iter()
497 .any(|override_entry| override_entry.ty.is_some());
498 let topology_requires_post_quant = self
499 .config
500 .topology
501 .as_ref()
502 .is_some_and(|topology| topology.requires_post_quantization());
503
504 let allow_immediate_cli = self.config.imatrix.is_none()
505 && self.config.calibration_file.is_none()
506 && !device.is_cuda()
507 && in_situ_quant.is_some();
508
509 let mut immediate_ty = None;
510 let mut immediate_predicates = Vec::new();
511 if allow_immediate_cli {
512 immediate_ty = in_situ_quant;
513 immediate_predicates = self.inner.immediate_isq_predicates(&config)?;
514 info!("Applying ISQ to {in_situ_quant:?}");
515 if immediate_predicates.is_empty() {
516 warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
517 }
518 }
519
520 let use_immediate = allow_immediate_cli || has_override_isq;
521 if use_immediate {
522 mistralrs_quant::set_immediate_isq_with_overrides(
523 immediate_ty,
524 immediate_predicates.clone(),
525 topology_overrides.clone(),
526 );
527 }
528
529 let mut loading_isq = if use_immediate {
531 false
532 } else {
533 in_situ_quant.is_some()
534 };
535 if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
536 loading_isq = true;
537 }
538 loading_isq |= topology_requires_post_quant;
539
540 if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
541 anyhow::bail!(
542 "`imatrix` and `calibration_file` were both specified, this is not allowed."
543 );
544 }
545
546 let load_device = if !loading_isq || self.config.calibration_file.is_some() {
548 loading_isq = false;
549 device.clone()
550 } else {
551 Device::Cpu
552 };
553
554 let attention_mechanism = if paged_attn_config.is_some() {
555 AttentionImplementation::PagedAttention
556 } else {
557 AttentionImplementation::Eager
558 };
559
560 let multi_progress = Arc::new(new_multi_progress());
561
562 let mut model = if use_nccl {
563 let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
564 dtype,
565 &device,
566 &available_devices,
567 silent,
568 &config,
569 loading_isq,
570 self.config.from_uqff.is_some(),
571 IsqOrganization::Default,
572 &*self.inner,
573 paths.as_ref(),
574 )?;
575
576 match self.kind {
578 ModelKind::Normal => vision_normal_model_loader_sharded!(
579 sharded_vb,
580 config,
581 self.inner,
582 mapper,
583 loading_isq,
584 device.clone(),
585 attention_mechanism,
586 multi_progress.clone(),
587 matformer_slicing_config.clone(),
588 ),
589 _ => unreachable!(),
590 }
591 } else {
592 match self.kind {
593 ModelKind::Normal => vision_normal_model_loader!(
594 paths,
595 Some(dtype),
596 &load_device,
597 layer_devices.clone(),
598 config,
599 self.inner,
600 silent,
601 mapper,
602 loading_isq,
603 self.config.from_uqff.is_some(),
604 device.clone(),
605 attention_mechanism,
606 multi_progress,
607 matformer_slicing_config.clone(),
608 ),
609 _ => unreachable!(),
610 }
611 };
612
613 let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
615 {
616 Some(preprocessor_config) => {
617 serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
618 }
619 None => PreProcessorConfig::default(),
620 };
621 let processor_config: Option<ProcessorConfig> = paths
622 .get_processor_config()
623 .as_ref()
624 .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
625
626 let processor = self.inner.get_processor(
627 &config,
628 processor_config,
629 preprocessor_config.clone(),
630 self.config.max_edge,
631 ); let tokenizer = get_tokenizer(
634 paths.get_tokenizer_filename(),
635 Some(processor.get_special_tokens()),
636 )?;
637
638 let gen_conf: Option<GenerationConfig> = paths
639 .get_gen_conf_filename()
640 .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
641 let chat_template_explicit = paths
642 .get_chat_template_explicit()
643 .as_ref()
644 .map(|x| x.to_string_lossy().to_string());
645 let chat_template = get_chat_template(
646 paths,
647 self.jinja_explicit.as_ref(),
648 chat_template_explicit.as_ref(),
649 self.chat_template.as_ref(),
650 None,
651 );
652
653 if let Some(calibration_file) = &self.config.calibration_file {
654 let calibration_data = std::fs::read_to_string(calibration_file)?;
655 let tokens = tokenizer
657 .encode_fast(calibration_data, false)
658 .map_err(anyhow::Error::msg)?
659 .get_ids()
660 .to_vec();
661 info!(
662 "Collecting imatrix from calibration file `{}` of {} tokens.",
663 calibration_file.display(),
664 tokens.len()
665 );
666 let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
667 let bos_tok_id = tokenizer
668 .token_to_id(&bos_toks[0])
669 .expect("Somehow the bos token is not present.");
670
671 model.begin_track_stats()?;
674
675 const CHUNK_SIZE: usize = 1024;
676 let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
677 let start = Instant::now();
678 for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
679 let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
680 let chunk_len = chunk.len();
681
682 let start = Instant::now();
683 let inputs = make_prompt_chunk(
684 0,
685 vec![&chunk],
686 &[0],
687 &load_device,
688 None,
689 false,
690 None,
691 None,
692 )?;
693 let _ = model.forward(
694 &inputs.input,
695 None, &inputs.positions,
697 inputs.context_lens,
698 inputs.position_ids,
699 model.default_model_specific_args(&inputs.input),
700 None,
701 &inputs.flash_meta,
702 )?;
703 match model.cache_mut() {
704 EitherCache::Full(full) => {
705 for layer in &mut *full.lock() {
706 *layer = None
707 }
708 }
709 EitherCache::Normal(normal) => {
710 for layer in &mut *normal.lock().unwrap().0 {
711 layer.reset();
712 }
713 }
714 EitherCache::Hybrid(hybrid) => {
715 hybrid.lock().unwrap().reset();
716 }
717 }
718 let end = Instant::now();
719 info!(
720 "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
721 i + 1,
722 end.duration_since(start).as_secs_f32()
723 );
724 }
725 load_device.synchronize()?;
726 let end = Instant::now();
727 info!(
728 "Finished collecting imatrix in {:.2}s",
729 end.duration_since(start).as_secs_f32()
730 );
731 }
732
733 let should_serialize = self.config.write_uqff.is_some();
734 let should_quantize_pass = loading_isq;
735
736 if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
737 let imatrix_source = if should_quantize_pass {
738 match (
739 self.config.imatrix.as_ref(),
740 self.config.calibration_file.is_some(),
741 ) {
742 (None, false) => None,
743 (Some(file), false) => Some(ImatrixDataSource::File(file)),
744 (None, true) => Some(ImatrixDataSource::Collected),
745 (Some(_), true) => unreachable!(),
746 }
747 } else {
748 None
749 };
750 if should_quantize_pass {
751 info!("Applying ISQ to all ranks.");
752 } else {
753 info!("Serializing existing ISQ tensors without additional quantization.");
754 }
755 model.quantize(
756 in_situ_quant,
757 device.clone(),
758 self.config.topology.as_ref(),
759 silent,
760 imatrix_source,
761 IsqOrganization::Default,
762 should_quantize_pass,
763 self.config.write_uqff.as_ref(),
764 UqffFullSer {
765 tokenizer: &tokenizer,
766 template_filename: paths.get_template_filename(),
767 generation_config: paths.get_gen_conf_filename(),
768 config: config.clone(),
769 processor_filename: paths.get_processor_config(),
770 preprocessor_filename: paths.get_preprocessor_config(),
771 modules: None,
772 module_paths: None,
773 },
774 Arc::new(new_multi_progress()),
775 )?;
776 } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
777 model.load_from_artifacts(
778 device.clone(),
779 self.config.topology.as_ref(),
780 silent,
781 from_uqff,
782 )?;
783 }
784
785 let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
786 anyhow::ensure!(
787 !matches!(self.kind, ModelKind::Adapter { .. }),
788 "PagedAttention does not support adapter models."
789 );
790 let cache_config = calculate_cache_config(
791 paged_attn_config.mem_gpu,
792 paged_attn_config.block_size,
793 dtype,
794 paged_attn_config.cache_type,
795 model.config(),
796 &device,
797 &layer_devices,
798 silent,
799 )?;
800 let cache_engine =
801 CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
802 (Some(cache_config), Some(cache_engine))
803 } else {
804 (None, None)
805 };
806
807 let max_seq_len = model.max_seq_len();
808 let llg_factory = build_llg_factory(tokenizer.clone())?;
809 let num_hidden_layers = match model.cache() {
810 EitherCache::Full(full) => full.lock().len(),
811 EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
812 EitherCache::Hybrid(hybrid) => hybrid.lock().unwrap().num_layers(),
813 };
814 let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
815 let sliding_window = model.config().sliding_window;
816 let model_metadata = Arc::new(model.config().clone());
817 Ok(Arc::new(Mutex::new(VisionPipeline {
818 model,
819 tokenizer: tokenizer.into(),
820 chat_template: Arc::new(chat_template),
821 model_id: self.model_id.clone(),
822 metadata: Arc::new(GeneralMetadata {
823 max_seq_len,
824 llg_factory: Some(llg_factory),
825 is_xlora: false,
826 num_hidden_layers,
827 eos_tok: eos,
828 kind: self.kind.clone(),
829 no_kv_cache: false,
830 no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
831 activation_dtype: dtype,
832 sliding_window,
833 cache_config,
834 cache_engine,
835 model_metadata: Some(model_metadata),
836 modalities: self.inner.modalities(&config)?,
837 }),
838 processor,
839 prefixer: self.inner.prefixer(&config),
840 preprocessor_config: Arc::new(preprocessor_config),
841 topology: self.config.topology.clone(),
842 silent,
843 template_filename: paths.get_template_filename().clone(),
844 generation_config: paths.get_gen_conf_filename().cloned(),
845 config,
846 processor_filename: paths.get_processor_config().clone(),
847 preprocessor_filename: paths.get_preprocessor_config().clone(),
848 mapper: pipeline_mapper,
849 imatrix: self.config.imatrix.clone(),
850 })))
851 }
852
853 fn get_id(&self) -> String {
854 self.model_id.to_string()
855 }
856
857 fn get_kind(&self) -> ModelKind {
858 self.kind.clone()
859 }
860}
861
862impl PreProcessingMixin for VisionPipeline {
863 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
864 Some(self.chat_template.clone())
865 }
866 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
867 Some(self.preprocessor_config.clone())
868 }
869 fn get_processor(&self) -> Arc<dyn super::Processor> {
870 self.processor.clone()
871 }
872}
873
874impl IsqPipelineMixin for VisionPipeline {
875 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
876 let device = self.device().clone();
877 self.model
878 .quantize(
879 Some(dtype),
880 device,
881 self.topology.as_ref(),
882 self.silent,
883 self.imatrix.as_ref().map(ImatrixDataSource::File),
884 IsqOrganization::Default,
885 true,
886 None,
887 UqffFullSer {
888 tokenizer: &self.tokenizer,
889 template_filename: &self.template_filename,
890 generation_config: self.generation_config.as_ref(),
891 config: self.config.clone(),
892 processor_filename: &self.processor_filename,
893 preprocessor_filename: &self.preprocessor_filename,
894 modules: None,
895 module_paths: None,
896 },
897 Arc::new(new_multi_progress()),
898 )
899 .map_err(anyhow::Error::msg)
900 }
901}
902
903impl CacheManagerMixin for VisionPipeline {
904 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
905 if matches!(self.model.cache(), EitherCache::Full(_)) {
906 FullCacheManager.clone_in_cache(self, seqs, false)
907 } else {
908 NormalCacheManager.clone_in_cache(self, seqs, false)
909 }
910 }
911 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
912 if matches!(self.model.cache(), EitherCache::Full(_)) {
913 FullCacheManager.clone_out_cache(self, seqs, false)
914 } else {
915 NormalCacheManager.clone_out_cache(self, seqs, false)
916 }
917 }
918 fn set_none_cache(
919 &self,
920 seqs: &mut [&mut Sequence],
921 reset_non_granular: bool,
922 modify_draft_cache: bool,
923
924 load_preallocated_cache: bool,
925 ) {
926 if matches!(self.model.cache(), EitherCache::Full(_)) {
927 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
928 } else {
929 NormalCacheManager.set_none_cache(
930 self,
931 seqs,
932 modify_draft_cache,
933 load_preallocated_cache,
934 );
935 }
936 if reset_non_granular {
937 self.reset_non_granular_state()
938 }
939 }
940 fn cache(&self) -> &EitherCache {
941 self.model.cache()
942 }
943}
944
945impl MetadataMixin for VisionPipeline {
946 fn device(&self) -> Device {
947 self.model.device().clone()
948 }
949 fn get_metadata(&self) -> Arc<GeneralMetadata> {
950 self.metadata.clone()
951 }
952 fn name(&self) -> String {
953 self.model_id.clone()
954 }
955 fn reset_non_granular_state(&self) {}
956 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
957 Some(self.tokenizer.clone())
958 }
959 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
960 Some(&*self.mapper)
961 }
962}
963
964#[async_trait::async_trait]
965impl Pipeline for VisionPipeline {
966 fn forward_inputs(
967 &mut self,
968 inputs: Box<dyn Any>,
969 return_raw_logits: bool,
970 ) -> candle_core::Result<ForwardInputsResult> {
971 let ModelInputs {
972 input_ids,
973 seqlen_offsets,
974 context_lens,
975 position_ids,
976 pixel_values,
977 model_specific_args,
978 paged_attn_meta,
979 flash_meta,
980 } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
981 let metadata = self.get_metadata();
982 let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
983 (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
984 (Some(_), None) => {
985 candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
987 }
988 (None, Some(_)) => {
989 candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
991 }
992 (None, None) => None,
993 };
994 let logits = self.model.forward(
995 &input_ids,
996 pixel_values,
997 &seqlen_offsets,
998 context_lens,
999 position_ids,
1000 model_specific_args,
1001 paged_attn_meta,
1002 &flash_meta,
1003 )?;
1004 if return_raw_logits {
1005 Ok(ForwardInputsResult::RawLogits { logits })
1006 } else {
1007 Ok(ForwardInputsResult::CausalGeneration { logits })
1008 }
1009 }
1010 async fn sample_causal_gen(
1011 &self,
1012 seqs: &mut [&mut Sequence],
1013 logits: Vec<Tensor>,
1014 prefix_cacher: &mut PrefixCacheManagerV2,
1015 disable_eos_stop: bool,
1016 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1017 ) -> Result<(), candle_core::Error> {
1018 sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1019 }
1020 fn category(&self) -> ModelCategory {
1021 ModelCategory::Vision {
1022 prefixer: self.prefixer.clone(),
1023 }
1024 }
1025}
1026
1027impl AnyMoePipelineMixin for VisionPipeline {
1028 fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
1029 self.model.finish_training(gate_model_id)
1030 }
1031 fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1032 self.model.get_vars()
1033 }
1034 fn amoe_base_model_trainable_params(&self) -> usize {
1035 self.model.trainable_params()
1036 }
1037 fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1038 self.model.take_cached_gating_outputs()
1039 }
1040 fn amoe_create_layers(
1041 &mut self,
1042 model_ids: Vec<String>,
1043 token: &TokenSource,
1044 revision: Option<String>,
1045 match_regex: &str,
1046 config: crate::amoe::AnyMoeConfig,
1047 dtype: candle_core::DType,
1048 dev: &Device,
1049 (prefix, mlp): (String, String),
1050 layers: Vec<usize>,
1051 expert_type: AnyMoeExpertType,
1052 silent: bool,
1053 gate_model_id: Option<String>,
1054 ) -> candle_core::Result<()> {
1055 let mut vbs = Vec::new();
1056 let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
1058 for model_id in model_ids {
1059 let model_id_str = &model_id;
1060 let model_id = Path::new(&model_id);
1061
1062 let api = {
1063 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1064 let mut api = ApiBuilder::from_cache(cache)
1065 .with_progress(!silent)
1066 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1067 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1068 api = api.with_cache_dir(x.into());
1069 }
1070 api.build().map_err(candle_core::Error::msg)?
1071 };
1072 let revision = revision.clone().unwrap_or("main".to_string());
1073 let api = api.repo(Repo::with_revision(
1074 model_id_str.clone(),
1075 RepoType::Model,
1076 revision.clone(),
1077 ));
1078
1079 let mut filenames = vec![];
1080 for rfilename in
1081 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1082 {
1083 filenames.push(api_get_file!(api, &rfilename, model_id));
1084 }
1085
1086 let regex = regex.clone();
1087 let match_regex_clone = match_regex.to_string();
1088 let layers_clone = layers.clone();
1089 let vb = from_mmaped_safetensors(
1090 filenames,
1091 vec![],
1092 Some(dtype),
1093 dev,
1094 vec![None],
1095 silent,
1096 None,
1097 move |key| {
1098 if regex.is_match(&key) {
1099 let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1102 let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1103 let layer_n = key[first_layer_idx + 1..last_layer_idx]
1104 .parse::<usize>()
1105 .unwrap();
1106 layers_clone.contains(&layer_n) || layers_clone.is_empty()
1107 } else {
1108 false
1109 }
1110 },
1111 Arc::new(|_| DeviceForLoadTensor::Base),
1112 )?;
1113 vbs.push(vb);
1114 }
1115
1116 let gate_vb = if let Some(gate_model_id) = gate_model_id {
1117 let model_id_str = &gate_model_id;
1118 let model_id = Path::new(&gate_model_id);
1119
1120 let api = {
1121 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1122 let mut api = ApiBuilder::from_cache(cache)
1123 .with_progress(!silent)
1124 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1125 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1126 api = api.with_cache_dir(x.into());
1127 }
1128 api.build().map_err(candle_core::Error::msg)?
1129 };
1130 let revision = revision.clone().unwrap_or("main".to_string());
1131 let api = api.repo(Repo::with_revision(
1132 model_id_str.clone(),
1133 RepoType::Model,
1134 revision.clone(),
1135 ));
1136
1137 let mut gate_filenames = vec![];
1138 for rfilename in
1139 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1140 {
1141 gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1142 }
1143 assert_eq!(
1144 gate_filenames.len(),
1145 1,
1146 "Gate model ID must contain only one .safetensors file"
1147 );
1148
1149 let vb = from_mmaped_safetensors(
1150 gate_filenames.clone(),
1151 vec![],
1152 Some(dtype),
1153 dev,
1154 vec![None],
1155 silent,
1156 None,
1157 |_| true,
1158 Arc::new(|_| DeviceForLoadTensor::Base),
1159 )?;
1160 info!(
1161 "Loaded gating layers from `{}`",
1162 gate_filenames[0].display()
1163 );
1164 Some(vb)
1165 } else {
1166 None
1167 };
1168
1169 self.model
1170 .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
1171 }
1172 fn amoe_supported(&self) -> bool {
1173 self.model.amoe_supported()
1174 }
1175}