use super::isq::ImatrixDataSource;
use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoVisionLoader,
    CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
    GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
    ModelKind, ModelPaths, MultimodalPromptPrefixer, Phi4MMLoader, PreProcessingMixin, Processor,
    Qwen2VLLoader, Qwen3VLLoader, TokenSource, VLlama4Loader, VLlamaLoader, VisionModel,
    VisionModelLoader,
};
use super::{
    Gemma3nLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader,
    Phi3VLoader, Qwen2_5VLLoader, VisionLoaderType,
};
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::llg::build_llg_factory;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{
    progress::{new_multi_progress, ProgressScopeGuard},
    tokens::get_token,
    varbuilder_utils::from_mmaped_safetensors,
};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::ModelInputs;
use crate::{
    api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
    vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::{
    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

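/// A fully loaded vision-language pipeline: the model together with its tokenizer,
/// chat template, input processors, device mapper, and general metadata.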
pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}

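/// [`Loader`] implementation for vision models. Resolves model files locally or from
/// the Hugging Face Hub, performs device mapping and optional ISQ/UQFF handling, and
/// produces a [`VisionPipeline`].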
pub struct VisionLoader {
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

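/// Builder for a vision [`Loader`]; construct with [`VisionLoaderBuilder::new`] and
/// finalize with [`VisionLoaderBuilder::build`].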
#[derive(Default)]
pub struct VisionLoaderBuilder {
    model_id: Option<String>,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

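/// Configuration specific to vision pipelines: topology, UQFF read/write paths,
/// maximum image edge, imatrix/calibration data, HF cache path, and Matformer slicing.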
#[derive(Clone, Default)]
pub struct VisionSpecificConfig {
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub max_edge: Option<u32>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
    pub matformer_config_path: Option<PathBuf>,
    pub matformer_slice_name: Option<String>,
}

impl VisionLoaderBuilder {
    pub fn new(
        config: VisionSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn build(self, loader: Option<VisionLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn VisionModelLoader> = match loader {
            Some(VisionLoaderType::Phi3V) => Box::new(Phi3VLoader),
            Some(VisionLoaderType::Idefics2) => Box::new(Idefics2Loader),
            Some(VisionLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
            Some(VisionLoaderType::LLaVA) => Box::new(LLaVALoader),
            Some(VisionLoaderType::VLlama) => Box::new(VLlamaLoader),
            Some(VisionLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
            Some(VisionLoaderType::Idefics3) => Box::new(Idefics3Loader),
            Some(VisionLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
            Some(VisionLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
            Some(VisionLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
            Some(VisionLoaderType::Gemma3) => Box::new(Gemma3Loader),
            Some(VisionLoaderType::Mistral3) => Box::new(Mistral3Loader),
            Some(VisionLoaderType::Llama4) => Box::new(VLlama4Loader),
            Some(VisionLoaderType::Gemma3n) => Box::new(Gemma3nLoader),
            Some(VisionLoaderType::Qwen3VL) => Box::new(Qwen3VLLoader),
            None => Box::new(AutoVisionLoader),
        };
        Box::new(VisionLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            xlora_model_id: None,
            xlora_order: None,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for VisionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

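    // Loads a model from already-resolved paths: reads the config, selects devices and a
    // device mapping, applies ISQ/UQFF handling and optional imatrix collection, then
    // assembles the `VisionPipeline` with its processors, tokenizer, and metadata.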
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention(&config) {
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

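        // Determine the candidate devices: the daemon worker's dedicated CUDA stream, a
        // single CUDA device when NCCL is in use, or all devices similar to the requested one.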
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        let matformer_slicing_config = if let Some(matformer_path) =
            &self.config.matformer_config_path
        {
            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
            info!("Loading Matformer config from {:?}", matformer_path);
            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);

            if let Some(slice_name) = &self.config.matformer_slice_name {
                info!("Using Matformer slice: {}", slice_name);
                Some(MatformerSliceConfig::new(slice_name.clone(), config))
            } else {
                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
                None
            }
        } else {
            None
        };

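        // Choose the device mapping: a dummy NCCL mapping, an automatic mapping computed
        // from per-layer size estimates, or the explicit mapping that was passed in.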
        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
            params = params.maybe_promote_to_vision();

            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        matformer_slicing_config.as_ref(),
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

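        // Two mappers are built from the same setting: one is stored on the pipeline for
        // later inspection, the other drives loading below.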
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu && paged_attn_config.is_some() {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        let topology_overrides = self
            .config
            .topology
            .as_ref()
            .map(|topology| {
                topology
                    .pattern_overrides()
                    .into_iter()
                    .map(|(regex, layer)| ImmediateIsqOverride {
                        predicate: regex,
                        ty: layer.isq,
                        device: layer.device.clone(),
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        let has_override_isq = topology_overrides
            .iter()
            .any(|override_entry| override_entry.ty.is_some());
        let topology_requires_post_quant = self
            .config
            .topology
            .as_ref()
            .is_some_and(|topology| topology.requires_post_quantization());

        let allow_immediate_cli = self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && in_situ_quant.is_some();

        let mut immediate_ty = None;
        let mut immediate_predicates = Vec::new();
        if allow_immediate_cli {
            immediate_ty = in_situ_quant;
            immediate_predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if immediate_predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
        }

        let use_immediate = allow_immediate_cli || has_override_isq;
        if use_immediate {
            mistralrs_quant::set_immediate_isq_with_overrides(
                immediate_ty,
                immediate_predicates.clone(),
                topology_overrides.clone(),
            );
        }

        let mut loading_isq = if use_immediate {
            false
        } else {
            in_situ_quant.is_some()
        };
        if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
            loading_isq = true;
        }
        loading_isq |= topology_requires_post_quant;

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(new_multi_progress());

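        // Build the model, either through the sharded (NCCL) loader or the standard
        // vision model loader macro.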
        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        };

        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        );

        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

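        // If a calibration file is provided, run its tokens through the model in chunks to
        // collect imatrix statistics before quantization.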
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            model.begin_track_stats()?;

            const CHUNK_SIZE: usize = 1024;
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    None,
                )?;
                let _ = model.forward(
                    &inputs.input,
                    None,
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                    EitherCache::Hybrid(hybrid) => {
                        hybrid.lock().unwrap().reset();
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

        let should_serialize = self.config.write_uqff.is_some();
        let should_quantize_pass = loading_isq;

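        // Quantize (ISQ) and/or serialize UQFF artifacts, unless pre-quantized UQFF
        // artifacts were supplied, in which case load them instead.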
        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
            let imatrix_source = if should_quantize_pass {
                match (
                    self.config.imatrix.as_ref(),
                    self.config.calibration_file.is_some(),
                ) {
                    (None, false) => None,
                    (Some(file), false) => Some(ImatrixDataSource::File(file)),
                    (None, true) => Some(ImatrixDataSource::Collected),
                    (Some(_), true) => unreachable!(),
                }
            } else {
                None
            };
            if should_quantize_pass {
                info!("Applying ISQ to all ranks.");
            } else {
                info!("Serializing existing ISQ tensors without additional quantization.");
            }
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                should_quantize_pass,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                    modules: None,
                    module_paths: None,
                },
                Arc::new(new_multi_progress()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

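        // Set up the PagedAttention KV-cache configuration and engine when enabled.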
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

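        // Gather the remaining metadata and assemble the pipeline.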
        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
            EitherCache::Hybrid(hybrid) => hybrid.lock().unwrap().num_layers(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                model_metadata: Some(model_metadata),
                modalities: self.inner.modalities(&config)?,
            }),
            processor,
            prefixer: self.inner.prefixer(&config),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for VisionPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        Some(self.preprocessor_config.clone())
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        self.processor.clone()
    }
}

impl IsqPipelineMixin for VisionPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                true,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                    modules: None,
                    module_paths: None,
                },
                Arc::new(new_multi_progress()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for VisionPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for VisionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Vision {
            prefixer: self.prefixer.clone(),
        }
    }
}

impl AnyMoePipelineMixin for VisionPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

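        // Optionally load the gating model (a single safetensors file) used by AnyMoE.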
        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}