use super::isq::ImatrixDataSource;
use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoVisionLoader,
    CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
    GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
    ModelKind, ModelPaths, MultimodalPromptPrefixer, Phi4MMLoader, PreProcessingMixin, Processor,
    Qwen2VLLoader, Qwen3VLLoader, TokenSource, VLlama4Loader, VLlamaLoader, VisionModel,
    VisionModelLoader,
};
use super::{
    Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader,
    Phi3VLoader, Qwen2_5VLLoader, VisionLoaderType,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::llg::build_llg_factory;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::ModelInputs;
use crate::{
    api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
    vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::{
    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}

pub struct VisionLoader {
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

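/// Builder for a vision [`Loader`]. A minimal, hypothetical usage sketch
/// (the model ID below is illustrative, not a real checkpoint):
///
/// ```ignore
/// let loader = VisionLoaderBuilder::new(
///     VisionSpecificConfig::default(),
///     None,                                 // chat template override
///     None,                                 // tokenizer.json override
///     Some("org/vision-model".to_string()), // HF model ID
///     None,                                 // explicit JINJA template
/// )
/// .build(None); // `None` auto-detects the architecture via `AutoVisionLoader`
/// ```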
#[derive(Default)]
pub struct VisionLoaderBuilder {
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

#[derive(Clone, Default)]
pub struct VisionSpecificConfig {
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub max_edge: Option<u32>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
    pub matformer_config_path: Option<PathBuf>,
    pub matformer_slice_name: Option<String>,
}

impl VisionLoaderBuilder {
    pub fn new(
        config: VisionSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn build(self, loader: Option<VisionLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn VisionModelLoader> = match loader {
            Some(VisionLoaderType::Phi3V) => Box::new(Phi3VLoader),
            Some(VisionLoaderType::Idefics2) => Box::new(Idefics2Loader),
            Some(VisionLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
            Some(VisionLoaderType::LLaVA) => Box::new(LLaVALoader),
            Some(VisionLoaderType::VLlama) => Box::new(VLlamaLoader),
            Some(VisionLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
            Some(VisionLoaderType::Idefics3) => Box::new(Idefics3Loader),
            Some(VisionLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
            Some(VisionLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
            Some(VisionLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
            Some(VisionLoaderType::Gemma3) => Box::new(Gemma3Loader),
            Some(VisionLoaderType::Mistral3) => Box::new(Mistral3Loader),
            Some(VisionLoaderType::Llama4) => Box::new(VLlama4Loader),
            Some(VisionLoaderType::Gemma3n) => Box::new(Gemma3nLoader),
            Some(VisionLoaderType::Qwen3VL) => Box::new(Qwen3VLLoader),
            None => Box::new(AutoVisionLoader),
        };
        Box::new(VisionLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            xlora_model_id: None,
            xlora_order: None,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for VisionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention(&config) {
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        let matformer_slicing_config = if let Some(matformer_path) =
            &self.config.matformer_config_path
        {
            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
            info!("Loading Matformer config from {:?}", matformer_path);
            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);

            if let Some(slice_name) = &self.config.matformer_slice_name {
                info!("Using Matformer slice: {}", slice_name);
                Some(MatformerSliceConfig::new(slice_name.clone(), config))
            } else {
                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
                None
            }
        } else {
            None
        };

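        // Decide layer placement: NCCL gets a dummy mapper on the primary device,
        // while `DeviceMapSetting::Auto` estimates per-layer memory (accounting for
        // quantization pack factors) to distribute layers across devices.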
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
            params = params.maybe_promote_to_vision();

            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

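            // Size estimation happens under one of three regimes: UQFF artifacts
            // (average the per-tensor pack factors read from the serialized ISQ
            // metadata), an explicitly requested ISQ type, or whatever pack factor
            // the model's own quantization config implies.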
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

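        // Two mappers are materialized from the same setting: `pipeline_mapper` is
        // retained on the final pipeline for metadata queries, while `mapper` drives
        // the actual weight loading.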
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu && paged_attn_config.is_some() {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

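        // A topology file may pin individual layers (matched by regex) to a device
        // and/or an ISQ type; these become immediate-ISQ overrides applied at load.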
        let topology_overrides = self
            .config
            .topology
            .as_ref()
            .map(|topology| {
                topology
                    .pattern_overrides()
                    .into_iter()
                    .map(|(regex, layer)| ImmediateIsqOverride {
                        predicate: regex,
                        ty: layer.isq,
                        device: layer.device.clone(),
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        let has_override_isq = topology_overrides
            .iter()
            .any(|override_entry| override_entry.ty.is_some());
        let topology_requires_post_quant = self
            .config
            .topology
            .as_ref()
            .is_some_and(|topology| topology.requires_post_quantization());

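        // Immediate ISQ quantizes tensors as they are loaded rather than in a
        // separate pass afterwards. It is not used when an imatrix or calibration
        // file is set (those need full-precision activations first) or on CUDA.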
        let allow_immediate_cli = self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && in_situ_quant.is_some();

        let mut immediate_ty = None;
        let mut immediate_predicates = Vec::new();
        if allow_immediate_cli {
            immediate_ty = in_situ_quant;
            immediate_predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if immediate_predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
        }

        let use_immediate = allow_immediate_cli || has_override_isq;
        if use_immediate {
            mistralrs_quant::set_immediate_isq_with_overrides(
                immediate_ty,
                immediate_predicates.clone(),
                topology_overrides.clone(),
            );
        }

        let mut loading_isq = if use_immediate {
            false
        } else {
            in_situ_quant.is_some()
        };
        if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
            loading_isq = true;
        }
        loading_isq |= topology_requires_post_quant;

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        };

        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        );

        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

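        // With a calibration file, run its tokens through the model in BOS-prefixed
        // chunks while tracking activation statistics, clearing the KV cache between
        // chunks; the collected imatrix feeds the quantization pass below.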
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            model.begin_track_stats()?;

            const CHUNK_SIZE: usize = 1024;
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    None,
                )?;
                let _ = model.forward(
                    &inputs.input,
                    None,
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

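        // Quantize and/or serialize. A fresh ISQ pass (optionally imatrix-weighted)
        // can also write a UQFF artifact; conversely, a preexisting UQFF artifact is
        // loaded back into the model instead of re-quantizing.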
        let should_serialize = self.config.write_uqff.is_some();
        let should_quantize_pass = loading_isq;

        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
            let imatrix_source = if should_quantize_pass {
                match (
                    self.config.imatrix.as_ref(),
                    self.config.calibration_file.is_some(),
                ) {
                    (None, false) => None,
                    (Some(file), false) => Some(ImatrixDataSource::File(file)),
                    (None, true) => Some(ImatrixDataSource::Collected),
                    (Some(_), true) => unreachable!(),
                }
            } else {
                None
            };
            if should_quantize_pass {
                info!("Applying ISQ to all ranks.");
            } else {
                info!("Serializing existing ISQ tensors without additional quantization.");
            }
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                should_quantize_pass,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                    modules: None,
                    module_paths: None,
                },
                Arc::new(MultiProgress::new()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

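        // With PagedAttention enabled, size the KV-cache block layout from the GPU
        // memory budget and build a cache engine spanning the mapped devices.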
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                model_metadata: Some(model_metadata),
                modalities: self.inner.modalities(&config)?,
            }),
            processor,
            prefixer: self.inner.prefixer(&config),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for VisionPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        Some(self.preprocessor_config.clone())
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        self.processor.clone()
    }
}

impl IsqPipelineMixin for VisionPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                true,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                    modules: None,
                    module_paths: None,
                },
                Arc::new(MultiProgress::new()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for VisionPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for VisionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Vision {
            prefixer: self.prefixer.clone(),
        }
    }
}

impl AnyMoePipelineMixin for VisionPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

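            // Keep only expert tensors whose layer index, parsed from the tensor key
            // (e.g. `...layers.{n}.<match_regex>...`; the key format is an assumption
            // here), is listed in `layers`; an empty list selects every match.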
            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}