1use super::isq::ImatrixDataSource;
2use super::isq::UqffFullSer;
3use super::{
4 get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoVisionLoader,
5 CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
6 GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
7 ModelKind, ModelPaths, Phi4MMLoader, PreProcessingMixin, Processor, Qwen2VLLoader, TokenSource,
8 VLlama4Loader, VLlamaLoader, VisionModel, VisionModelLoader, VisionPromptPrefixer,
9};
10use super::{
11 Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader, Phi3VLoader,
12 Qwen2_5VLLoader, VisionLoaderType,
13};
14use crate::device_map::{self, DeviceMapper};
15use crate::distributed::{self, WorkerTransferData};
16use crate::kv_cache::{FullCacheManager, NormalCacheManager};
17use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
18use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
19use crate::pipeline::llg::build_llg_factory;
20use crate::pipeline::loaders::QuantizationConfigShim;
21use crate::pipeline::sampling::sample_and_add_toks;
22use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
23use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
24use crate::prefix_cacher::PrefixCacheManagerV2;
25use crate::sequence::Sequence;
26use crate::utils::tokenizer::get_tokenizer;
27use crate::utils::varbuilder_utils::DeviceForLoadTensor;
28use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
29use crate::vision_models::preprocessor_config::PreProcessorConfig;
30use crate::vision_models::processor_config::ProcessorConfig;
31use crate::vision_models::ModelInputs;
32use crate::{
33 api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
34 vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
35 PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
36};
37use anyhow::Result;
38use candle_core::{Device, Tensor, Var};
39use hf_hub::Cache;
40use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
41use indicatif::MultiProgress;
42use mistralrs_quant::log::once_log_info;
43use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
44use rand_isaac::Isaac64Rng;
45use regex_automata::meta::Regex;
46use std::any::Any;
47use std::borrow::Cow;
48use std::num::NonZeroUsize;
49use std::path::{Path, PathBuf};
50use std::str::FromStr;
51use std::sync::{Arc, RwLock};
52use std::time::Instant;
53use std::{env, fs};
54use tokenizers::Tokenizer;
55use tokio::sync::Mutex;
56use tracing::{info, warn};
57
58pub struct VisionPipeline {
59 model: Box<dyn VisionModel + Send + Sync>,
60 tokenizer: Arc<Tokenizer>,
61 chat_template: Arc<ChatTemplate>,
62 model_id: String,
63 metadata: Arc<GeneralMetadata>,
64 processor: Arc<dyn Processor + Send + Sync>,
65 preprocessor_config: Arc<PreProcessorConfig>,
66 topology: Option<Topology>,
67 silent: bool,
68 prefixer: Arc<dyn VisionPromptPrefixer>,
69 mapper: Box<dyn DeviceMapper + Send + Sync>,
70
71 template_filename: Option<PathBuf>,
73 generation_config: Option<PathBuf>,
74 config: String,
75 processor_filename: Option<PathBuf>,
76 preprocessor_filename: Option<PathBuf>,
77 imatrix: Option<PathBuf>,
78}
79
80pub struct VisionLoader {
82 inner: Box<dyn VisionModelLoader>,
83 model_id: String,
84 config: VisionSpecificConfig,
85 kind: ModelKind,
86 chat_template: Option<String>,
87 tokenizer_json: Option<String>,
88 xlora_model_id: Option<String>,
89 xlora_order: Option<Ordering>,
90 token_source: RwLock<Option<TokenSource>>,
91 revision: RwLock<Option<String>>,
92 from_uqff: RwLock<Option<Vec<PathBuf>>>,
93 jinja_explicit: Option<String>,
94 hf_cache_path: Option<PathBuf>,
95 lora_adapter_ids: Option<Vec<String>>,
96}
97
98#[derive(Default)]
99pub struct VisionLoaderBuilder {
101 model_id: Option<String>,
102 config: VisionSpecificConfig,
103 kind: ModelKind,
104 chat_template: Option<String>,
105 tokenizer_json: Option<String>,
106 jinja_explicit: Option<String>,
107 hf_cache_path: Option<PathBuf>,
108 lora_adapter_ids: Option<Vec<String>>,
109}
110
111#[derive(Clone, Default)]
112pub struct VisionSpecificConfig {
114 pub prompt_chunksize: Option<NonZeroUsize>,
115 pub topology: Option<Topology>,
116 pub write_uqff: Option<PathBuf>,
117 pub from_uqff: Option<Vec<PathBuf>>,
118 pub max_edge: Option<u32>,
119 pub imatrix: Option<PathBuf>,
120 pub calibration_file: Option<PathBuf>,
121 pub hf_cache_path: Option<PathBuf>,
122}
123
124impl VisionLoaderBuilder {
125 pub fn new(
126 config: VisionSpecificConfig,
127 chat_template: Option<String>,
128 tokenizer_json: Option<String>,
129 model_id: Option<String>,
130 jinja_explicit: Option<String>,
131 ) -> Self {
132 Self {
133 config,
134 chat_template,
135 tokenizer_json,
136 model_id,
137 jinja_explicit,
138 kind: ModelKind::Normal,
139 hf_cache_path: None,
140 ..Default::default()
141 }
142 }
143
144 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
145 self.hf_cache_path = Some(hf_cache_path);
146 self
147 }
148
149 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
150 self.kind = ModelKind::Adapter {
151 adapter: AdapterKind::Lora,
152 };
153 self.lora_adapter_ids = Some(lora_adapter_ids);
154 self
155 }
156
157 pub fn build(self, loader: Option<VisionLoaderType>) -> Box<dyn Loader> {
158 let loader: Box<dyn VisionModelLoader> = match loader {
159 Some(VisionLoaderType::Phi3V) => Box::new(Phi3VLoader),
160 Some(VisionLoaderType::Idefics2) => Box::new(Idefics2Loader),
161 Some(VisionLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
162 Some(VisionLoaderType::LLaVA) => Box::new(LLaVALoader),
163 Some(VisionLoaderType::VLlama) => Box::new(VLlamaLoader),
164 Some(VisionLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
165 Some(VisionLoaderType::Idefics3) => Box::new(Idefics3Loader),
166 Some(VisionLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
167 Some(VisionLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
168 Some(VisionLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
169 Some(VisionLoaderType::Gemma3) => Box::new(Gemma3Loader),
170 Some(VisionLoaderType::Mistral3) => Box::new(Mistral3Loader),
171 Some(VisionLoaderType::Llama4) => Box::new(VLlama4Loader),
172 None => Box::new(AutoVisionLoader),
173 };
174 Box::new(VisionLoader {
175 inner: loader,
176 model_id: self.model_id.unwrap(),
177 config: self.config,
178 kind: self.kind,
179 chat_template: self.chat_template,
180 tokenizer_json: self.tokenizer_json,
181 xlora_model_id: None,
182 xlora_order: None,
183 jinja_explicit: self.jinja_explicit,
184 token_source: RwLock::new(None),
185 revision: RwLock::new(None),
186 from_uqff: RwLock::new(None),
187 hf_cache_path: self.hf_cache_path,
188 lora_adapter_ids: self.lora_adapter_ids,
189 })
190 }
191}
192
193impl Loader for VisionLoader {
194 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
195 fn load_model_from_hf(
196 &self,
197 revision: Option<String>,
198 token_source: TokenSource,
199 dtype: &dyn TryIntoDType,
200 device: &Device,
201 silent: bool,
202 mapper: DeviceMapSetting,
203 in_situ_quant: Option<IsqType>,
204 paged_attn_config: Option<PagedAttentionConfig>,
205 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
206 let cache = self
207 .hf_cache_path
208 .clone()
209 .map(Cache::new)
210 .unwrap_or_default();
211 GLOBAL_HF_CACHE.get_or_init(|| cache);
212
213 let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
214 LocalModelPaths,
215 &token_source,
216 revision.clone(),
217 self,
218 None,
219 None,
220 silent,
221 self.config.from_uqff.is_some()
222 );
223 if let Some(from_uqff) = self.config.from_uqff.clone() {
224 *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
225 }
226 *self
227 .token_source
228 .write()
229 .expect("Failed to write to token source") = Some(token_source);
230 *self.revision.write().expect("Failed to write to revision") = revision;
231 self.load_model_from_path(
232 &paths?,
233 dtype,
234 device,
235 silent,
236 mapper,
237 in_situ_quant,
238 paged_attn_config,
239 )
240 }
241
242 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
243 fn load_model_from_path(
244 &self,
245 paths: &Box<dyn ModelPaths>,
246 dtype: &dyn TryIntoDType,
247 device: &Device,
248 silent: bool,
249 mut mapper: DeviceMapSetting,
250 mut in_situ_quant: Option<IsqType>,
251 mut paged_attn_config: Option<PagedAttentionConfig>,
252 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
253 let config = std::fs::read_to_string(paths.get_config_filename())?;
254
255 if !self.inner.supports_paged_attention(&config) {
256 paged_attn_config = None;
257 }
258
259 let use_nccl = mistralrs_quant::distributed::use_nccl();
260
261 let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
262 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
263 let WorkerTransferData::Init { id: _, worker_rank } = payload;
264 vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
265 } else if use_nccl {
266 vec![candle_core::Device::new_cuda_with_stream(0)?]
267 } else {
268 device_map::get_all_similar_devices(device)?
269 };
270 let device = if use_nccl {
271 available_devices[0].clone()
272 } else {
273 device.clone()
274 };
275
276 if use_nccl {
278 mapper = DeviceMapSetting::DummyNccl {
279 nm_device: available_devices[0].clone(),
280 };
281 } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
282 let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
284
285 if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
287 in_situ_quant = None;
288 }
289
290 let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
293 if let Some(serialized) = &*self.from_uqff.read().unwrap() {
294 let weight_pack_factor = {
295 let ser_artifacts = unsafe {
296 candle_core::safetensors::MmapedSafetensors::multi(serialized)?
297 };
298 let mut total_pack_factors = 0;
299 let total_tensors = ser_artifacts.tensors().len();
300 for (_, artifact) in ser_artifacts.tensors() {
301 let artifact = artifact.data();
302 let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
304 let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
305 {
306 QuantizedSerdeType::Hqq => {
307 HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
308 .pack_factor(dtype)
309 }
310 QuantizedSerdeType::Gguf => {
311 GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
312 .pack_factor(dtype)
313 }
314 QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
315 QuantizedSerdeType::Unquant => 1,
316 QuantizedSerdeType::Afq => {
317 AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
318 .pack_factor(dtype)
319 }
320 };
321 total_pack_factors += pack_factor;
322 }
323
324 total_pack_factors / total_tensors
325 };
326
327 let layer_sizes_in_bytes =
328 self.inner
329 .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
330 let non_mapped_size_in_bytes =
331 self.inner
332 .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
333 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
334 (
335 layer_sizes_in_bytes,
336 non_mapped_size_in_bytes,
337 layer_sizes_sum + non_mapped_size_in_bytes,
338 )
339 } else if let Some(isq) = in_situ_quant {
340 let weight_pack_factor = isq.pack_factor(dtype);
341 let layer_sizes_in_bytes =
342 self.inner
343 .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
344 let non_mapped_size_in_bytes =
345 self.inner
346 .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
347 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
348 (
349 layer_sizes_in_bytes,
350 non_mapped_size_in_bytes,
351 layer_sizes_sum + non_mapped_size_in_bytes,
352 )
353 } else {
354 let weight_pack_factor =
356 QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
357 let layer_sizes_in_bytes =
358 self.inner
359 .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
360 let non_mapped_size_in_bytes =
361 self.inner
362 .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
363 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
364 (
365 layer_sizes_in_bytes,
366 non_mapped_size_in_bytes,
367 layer_sizes_sum + non_mapped_size_in_bytes,
368 )
369 };
370
371 let new = self.inner.get_device_layers(
372 &config,
373 self.inner.num_layers(&config)?,
374 layer_sizes_in_bytes,
375 non_mapped_size_in_bytes,
376 total_model_size_in_bytes,
377 &available_devices,
378 dtype,
379 ¶ms,
380 params.max_seq_len(),
381 paged_attn_config.as_ref(),
382 )?;
383 mapper = DeviceMapSetting::Map(new);
384 }
385
386 let pipeline_mapper = mapper.into_mapper(
387 self.inner.num_layers(&config)?,
388 &device,
389 self.config.topology.as_ref(),
390 )?;
391 let mapper = mapper.into_mapper(
392 self.inner.num_layers(&config)?,
393 &device,
394 self.config.topology.as_ref(),
395 )?;
396 let mut layer_devices = Vec::new();
397 for layer in 0..self.inner.num_layers(&config)? {
398 let device = mapper.device_for(layer, false).cloned();
399 layer_devices.push(device);
400 }
401 let dtype = mapper.get_min_dtype(dtype)?;
402
403 let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
406 if mapping_uses_cpu {
407 warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
408 paged_attn_config = None;
409 }
410
411 info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
412 if crate::using_flash_attn() {
413 once_log_info("FlashAttention is enabled.");
414 }
415
416 let mut loading_isq = if self.config.imatrix.is_none()
418 && self.config.calibration_file.is_none()
419 && !device.is_cuda()
420 && self.config.write_uqff.is_none()
421 && in_situ_quant.is_some()
422 {
423 let predicates = self.inner.immediate_isq_predicates(&config)?;
424 info!("Applying ISQ to {in_situ_quant:?}");
425 if predicates.is_empty() {
426 warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
427 }
428 mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
429 false
430 } else {
431 in_situ_quant.is_some()
432 };
433
434 if let Some(ref topology) = self.config.topology {
435 loading_isq |= topology
436 .0
437 .iter()
438 .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
439 }
440
441 if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
442 anyhow::bail!(
443 "`imatrix` and `calibration_file` were both specified, this is not allowed."
444 );
445 }
446
447 let load_device = if !loading_isq || self.config.calibration_file.is_some() {
449 loading_isq = false;
450 device.clone()
451 } else {
452 Device::Cpu
453 };
454
455 let attention_mechanism = if paged_attn_config.is_some() {
456 AttentionImplementation::PagedAttention
457 } else {
458 AttentionImplementation::Eager
459 };
460
461 let multi_progress = Arc::new(MultiProgress::new());
462
463 let mut model = if use_nccl {
464 let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
465 dtype,
466 &device,
467 &available_devices,
468 silent,
469 &config,
470 loading_isq,
471 self.config.from_uqff.is_some(),
472 IsqOrganization::Default,
473 &*self.inner,
474 paths.as_ref(),
475 )?;
476
477 match self.kind {
479 ModelKind::Normal => vision_normal_model_loader_sharded!(
480 sharded_vb,
481 config,
482 self.inner,
483 mapper,
484 loading_isq,
485 device.clone(),
486 attention_mechanism,
487 multi_progress.clone(),
488 ),
489 _ => unreachable!(),
490 }
491 } else {
492 match self.kind {
493 ModelKind::Normal => vision_normal_model_loader!(
494 paths,
495 Some(dtype),
496 &load_device,
497 layer_devices.clone(),
498 config,
499 self.inner,
500 silent,
501 mapper,
502 loading_isq,
503 self.config.from_uqff.is_some(),
504 device.clone(),
505 attention_mechanism,
506 multi_progress,
507 ),
508 _ => unreachable!(),
509 }
510 };
511
512 let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
514 {
515 Some(preprocessor_config) => {
516 serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
517 }
518 None => PreProcessorConfig::default(),
519 };
520 let processor_config: Option<ProcessorConfig> = paths
521 .get_processor_config()
522 .as_ref()
523 .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
524
525 let processor = self.inner.get_processor(
526 &config,
527 processor_config,
528 preprocessor_config.clone(),
529 self.config.max_edge,
530 ); let tokenizer = get_tokenizer(
533 paths.get_tokenizer_filename(),
534 Some(processor.get_special_tokens()),
535 )?;
536
537 let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
538 serde_json::from_str(&fs::read_to_string(f).unwrap())
539 .expect("bos_token_id/eos_token_id missing in generation_config.json")
540 });
541 let chat_template_explicit = paths
542 .get_chat_template_explicit()
543 .as_ref()
544 .map(|x| x.to_string_lossy().to_string());
545 let chat_template = get_chat_template(
546 paths,
547 self.jinja_explicit.as_ref(),
548 chat_template_explicit.as_ref(),
549 self.chat_template.as_ref(),
550 None,
551 );
552
553 if let Some(calibration_file) = &self.config.calibration_file {
554 let calibration_data = std::fs::read_to_string(calibration_file)?;
555 let tokens = tokenizer
557 .encode_fast(calibration_data, false)
558 .map_err(anyhow::Error::msg)?
559 .get_ids()
560 .to_vec();
561 info!(
562 "Collecting imatrix from calibration file `{}` of {} tokens.",
563 calibration_file.display(),
564 tokens.len()
565 );
566 let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
567 let bos_tok_id = tokenizer
568 .token_to_id(&bos_toks[0])
569 .expect("Somehow the bos token is not present.");
570
571 model.begin_track_stats()?;
574
575 const CHUNK_SIZE: usize = 1024;
576 let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
577 let start = Instant::now();
578 for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
579 let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
580 let chunk_len = chunk.len();
581
582 let start = Instant::now();
583 let inputs = make_prompt_chunk(
584 0,
585 vec![&chunk],
586 &[0],
587 &load_device,
588 None,
589 false,
590 None,
591 None,
592 )?;
593 let _ = model.forward(
594 &inputs.input,
595 None, &inputs.positions,
597 inputs.context_lens,
598 inputs.position_ids,
599 model.default_model_specific_args(&inputs.input),
600 None,
601 &inputs.flash_meta,
602 )?;
603 match model.cache_mut() {
604 EitherCache::Full(full) => {
605 for layer in &mut *full.lock() {
606 *layer = None
607 }
608 }
609 EitherCache::Normal(normal) => {
610 for layer in &mut *normal.lock().unwrap().0 {
611 layer.reset();
612 }
613 }
614 }
615 let end = Instant::now();
616 info!(
617 "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
618 i + 1,
619 end.duration_since(start).as_secs_f32()
620 );
621 }
622 load_device.synchronize()?;
623 let end = Instant::now();
624 info!(
625 "Finished collecting imatrix in {:.2}s",
626 end.duration_since(start).as_secs_f32()
627 );
628 }
629
630 if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
632 let imatrix_source = match (
633 self.config.imatrix.as_ref(),
634 self.config.calibration_file.is_some(),
635 ) {
636 (None, false) => None,
637 (Some(file), false) => Some(ImatrixDataSource::File(file)),
638 (None, true) => Some(ImatrixDataSource::Collected),
639 (Some(_), true) => unreachable!(),
640 };
641 model.quantize(
642 in_situ_quant,
643 device.clone(),
644 self.config.topology.as_ref(),
645 silent,
646 imatrix_source,
647 IsqOrganization::Default,
648 self.config.write_uqff.as_ref(),
649 UqffFullSer {
650 tokenizer: &tokenizer,
651 template_filename: paths.get_template_filename(),
652 generation_config: paths.get_gen_conf_filename(),
653 config: config.clone(),
654 processor_filename: paths.get_processor_config(),
655 preprocessor_filename: paths.get_preprocessor_config(),
656 },
657 Arc::new(MultiProgress::new()),
658 )?;
659 } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
660 model.load_from_artifacts(
661 device.clone(),
662 self.config.topology.as_ref(),
663 silent,
664 from_uqff,
665 )?;
666 }
667
668 let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
669 anyhow::ensure!(
670 !matches!(self.kind, ModelKind::Adapter { .. }),
671 "PagedAttention does not support adapter models."
672 );
673 let cache_config = calculate_cache_config(
674 paged_attn_config.mem_gpu,
675 paged_attn_config.mem_cpu,
676 paged_attn_config.block_size,
677 dtype,
678 model.config(),
679 &device,
680 &layer_devices,
681 silent,
682 )?;
683 let cache_engine =
684 CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
685 (Some(cache_config), Some(cache_engine))
686 } else {
687 (None, None)
688 };
689
690 let max_seq_len = model.max_seq_len();
691 let llg_factory = build_llg_factory(tokenizer.clone())?;
692 let num_hidden_layers = match model.cache() {
693 EitherCache::Full(full) => full.lock().len(),
694 EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
695 };
696 let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
697 let sliding_window = model.config().sliding_window;
698 let model_metadata = Arc::new(model.config().clone());
699 Ok(Arc::new(Mutex::new(VisionPipeline {
700 model,
701 tokenizer: tokenizer.into(),
702 chat_template: Arc::new(chat_template),
703 model_id: self.model_id.clone(),
704 metadata: Arc::new(GeneralMetadata {
705 max_seq_len,
706 llg_factory: Some(llg_factory),
707 is_xlora: false,
708 num_hidden_layers,
709 eos_tok: eos,
710 kind: self.kind.clone(),
711 no_kv_cache: false,
712 no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
713 activation_dtype: dtype,
714 sliding_window,
715 cache_config,
716 cache_engine,
717 prompt_chunksize: self.config.prompt_chunksize,
718 model_metadata: Some(model_metadata),
719 }),
720 processor,
721 prefixer: self.inner.prefixer(&config),
722 preprocessor_config: Arc::new(preprocessor_config),
723 topology: self.config.topology.clone(),
724 silent,
725 template_filename: paths.get_template_filename().clone(),
726 generation_config: paths.get_gen_conf_filename().cloned(),
727 config,
728 processor_filename: paths.get_processor_config().clone(),
729 preprocessor_filename: paths.get_preprocessor_config().clone(),
730 mapper: pipeline_mapper,
731 imatrix: self.config.imatrix.clone(),
732 })))
733 }
734
735 fn get_id(&self) -> String {
736 self.model_id.to_string()
737 }
738
739 fn get_kind(&self) -> ModelKind {
740 self.kind.clone()
741 }
742}
743
744impl PreProcessingMixin for VisionPipeline {
745 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
746 Some(self.chat_template.clone())
747 }
748 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
749 Some(self.preprocessor_config.clone())
750 }
751 fn get_processor(&self) -> Arc<dyn super::Processor> {
752 self.processor.clone()
753 }
754}
755
756impl IsqPipelineMixin for VisionPipeline {
757 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
758 let device = self.device().clone();
759 self.model
760 .quantize(
761 Some(dtype),
762 device,
763 self.topology.as_ref(),
764 self.silent,
765 self.imatrix.as_ref().map(ImatrixDataSource::File),
766 IsqOrganization::Default,
767 None,
768 UqffFullSer {
769 tokenizer: &self.tokenizer,
770 template_filename: &self.template_filename,
771 generation_config: self.generation_config.as_ref(),
772 config: self.config.clone(),
773 processor_filename: &self.processor_filename,
774 preprocessor_filename: &self.preprocessor_filename,
775 },
776 Arc::new(MultiProgress::new()),
777 )
778 .map_err(anyhow::Error::msg)
779 }
780}
781
782impl CacheManagerMixin for VisionPipeline {
783 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
784 if matches!(self.model.cache(), EitherCache::Full(_)) {
785 FullCacheManager.clone_in_cache(self, seqs, false)
786 } else {
787 NormalCacheManager.clone_in_cache(self, seqs, false)
788 }
789 }
790 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
791 if matches!(self.model.cache(), EitherCache::Full(_)) {
792 FullCacheManager.clone_out_cache(self, seqs, false)
793 } else {
794 NormalCacheManager.clone_out_cache(self, seqs, false)
795 }
796 }
797 fn set_none_cache(
798 &self,
799 seqs: &mut [&mut Sequence],
800 reset_non_granular: bool,
801 modify_draft_cache: bool,
802
803 load_preallocated_cache: bool,
804 ) {
805 if matches!(self.model.cache(), EitherCache::Full(_)) {
806 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
807 } else {
808 NormalCacheManager.set_none_cache(
809 self,
810 seqs,
811 modify_draft_cache,
812 load_preallocated_cache,
813 );
814 }
815 if reset_non_granular {
816 self.reset_non_granular_state()
817 }
818 }
819 fn cache(&self) -> &EitherCache {
820 self.model.cache()
821 }
822}
823
824impl MetadataMixin for VisionPipeline {
825 fn device(&self) -> Device {
826 self.model.device().clone()
827 }
828 fn get_metadata(&self) -> Arc<GeneralMetadata> {
829 self.metadata.clone()
830 }
831 fn name(&self) -> String {
832 self.model_id.clone()
833 }
834 fn reset_non_granular_state(&self) {}
835 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
836 Some(self.tokenizer.clone())
837 }
838 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
839 Some(&*self.mapper)
840 }
841}
842
843#[async_trait::async_trait]
844impl Pipeline for VisionPipeline {
845 fn forward_inputs(
846 &mut self,
847 inputs: Box<dyn Any>,
848 return_raw_logits: bool,
849 ) -> candle_core::Result<ForwardInputsResult> {
850 let ModelInputs {
851 input_ids,
852 seqlen_offsets,
853 context_lens,
854 position_ids,
855 pixel_values,
856 model_specific_args,
857 paged_attn_meta,
858 flash_meta,
859 } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
860 let metadata = self.get_metadata();
861 let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
862 (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
863 (Some(_), None) => {
864 candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
866 }
867 (None, Some(_)) => {
868 candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
870 }
871 (None, None) => None,
872 };
873 let logits = self.model.forward(
874 &input_ids,
875 pixel_values,
876 &seqlen_offsets,
877 context_lens,
878 position_ids,
879 model_specific_args,
880 paged_attn_meta,
881 &flash_meta,
882 )?;
883 if return_raw_logits {
884 Ok(ForwardInputsResult::RawLogits { logits })
885 } else {
886 Ok(ForwardInputsResult::CausalGeneration { logits })
887 }
888 }
889 async fn sample_causal_gen(
890 &self,
891 seqs: &mut [&mut Sequence],
892 logits: Vec<Tensor>,
893 prefix_cacher: &mut PrefixCacheManagerV2,
894 disable_eos_stop: bool,
895 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
896 ) -> Result<(), candle_core::Error> {
897 sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
898 }
899 fn category(&self) -> ModelCategory {
900 ModelCategory::Vision {
901 prefixer: self.prefixer.clone(),
902 }
903 }
904}
905
906impl AnyMoePipelineMixin for VisionPipeline {
907 fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
908 self.model.finish_training(gate_model_id)
909 }
910 fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
911 self.model.get_vars()
912 }
913 fn amoe_base_model_trainable_params(&self) -> usize {
914 self.model.trainable_params()
915 }
916 fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
917 self.model.take_cached_gating_outputs()
918 }
919 fn amoe_create_layers(
920 &mut self,
921 model_ids: Vec<String>,
922 token: &TokenSource,
923 revision: Option<String>,
924 match_regex: &str,
925 config: crate::amoe::AnyMoeConfig,
926 dtype: candle_core::DType,
927 dev: &Device,
928 (prefix, mlp): (String, String),
929 layers: Vec<usize>,
930 expert_type: AnyMoeExpertType,
931 silent: bool,
932 gate_model_id: Option<String>,
933 ) -> candle_core::Result<()> {
934 let mut vbs = Vec::new();
935 let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
937 for model_id in model_ids {
938 let model_id_str = &model_id;
939 let model_id = Path::new(&model_id);
940
941 let api = {
942 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
943 let mut api = ApiBuilder::from_cache(cache)
944 .with_progress(!silent)
945 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
946 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
947 api = api.with_cache_dir(x.into());
948 }
949 api.build().map_err(candle_core::Error::msg)?
950 };
951 let revision = revision.clone().unwrap_or("main".to_string());
952 let api = api.repo(Repo::with_revision(
953 model_id_str.clone(),
954 RepoType::Model,
955 revision.clone(),
956 ));
957
958 let mut filenames = vec![];
959 for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
960 filenames.push(api_get_file!(api, &rfilename, model_id));
961 }
962
963 let regex = regex.clone();
964 let match_regex_clone = match_regex.to_string();
965 let layers_clone = layers.clone();
966 let vb = from_mmaped_safetensors(
967 filenames,
968 vec![],
969 Some(dtype),
970 dev,
971 vec![None],
972 silent,
973 None,
974 move |key| {
975 if regex.is_match(&key) {
976 let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
979 let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
980 let layer_n = key[first_layer_idx + 1..last_layer_idx]
981 .parse::<usize>()
982 .unwrap();
983 layers_clone.contains(&layer_n) || layers_clone.is_empty()
984 } else {
985 false
986 }
987 },
988 Arc::new(|_| DeviceForLoadTensor::Base),
989 )?;
990 vbs.push(vb);
991 }
992
993 let gate_vb = if let Some(gate_model_id) = gate_model_id {
994 let model_id_str = &gate_model_id;
995 let model_id = Path::new(&gate_model_id);
996
997 let api = {
998 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
999 let mut api = ApiBuilder::from_cache(cache)
1000 .with_progress(!silent)
1001 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1002 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1003 api = api.with_cache_dir(x.into());
1004 }
1005 api.build().map_err(candle_core::Error::msg)?
1006 };
1007 let revision = revision.clone().unwrap_or("main".to_string());
1008 let api = api.repo(Repo::with_revision(
1009 model_id_str.clone(),
1010 RepoType::Model,
1011 revision.clone(),
1012 ));
1013
1014 let mut gate_filenames = vec![];
1015 for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
1016 gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1017 }
1018 assert_eq!(
1019 gate_filenames.len(),
1020 1,
1021 "Gate model ID must contain only one .safetensors file"
1022 );
1023
1024 let vb = from_mmaped_safetensors(
1025 gate_filenames.clone(),
1026 vec![],
1027 Some(dtype),
1028 dev,
1029 vec![None],
1030 silent,
1031 None,
1032 |_| true,
1033 Arc::new(|_| DeviceForLoadTensor::Base),
1034 )?;
1035 info!(
1036 "Loaded gating layers from `{}`",
1037 gate_filenames[0].display()
1038 );
1039 Some(vb)
1040 } else {
1041 None
1042 };
1043
1044 self.model
1045 .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
1046 }
1047 fn amoe_supported(&self) -> bool {
1048 self.model.amoe_supported()
1049 }
1050}