use super::isq::ImatrixDataSource;
use super::isq::UqffFullSer;
use super::{
    get_model_paths, get_xlora_paths, AdapterKind, AnyMoePipelineMixin, AutoVisionLoader,
    CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, Gemma3Loader,
    GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, MiniCpmOLoader, ModelCategory,
    ModelKind, ModelPaths, MultimodalPromptPrefixer, Phi4MMLoader, PreProcessingMixin, Processor,
    Qwen2VLLoader, TokenSource, VLlama4Loader, VLlamaLoader, VisionModel, VisionModelLoader,
};
use super::{
    Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Mistral3Loader, Phi3VLoader,
    Qwen2_5VLLoader, VisionLoaderType,
};
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::kv_cache::{FullCacheManager, NormalCacheManager};
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::llg::build_llg_factory;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::ModelInputs;
use crate::{
    api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader,
    vision_normal_model_loader_sharded, AnyMoeExpertType, DeviceMapSetting, Ordering,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

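/// A fully loaded vision-language pipeline: the model together with everything
/// needed to preprocess requests (tokenizer, chat template, image processors)
/// and to run and sample generations.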
pub struct VisionPipeline {
    model: Box<dyn VisionModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    chat_template: Arc<ChatTemplate>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    processor: Arc<dyn Processor + Send + Sync>,
    preprocessor_config: Arc<PreProcessorConfig>,
    topology: Option<Topology>,
    silent: bool,
    prefixer: Arc<dyn MultimodalPromptPrefixer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,

    // Retained for later UQFF serialization of this pipeline.
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    processor_filename: Option<PathBuf>,
    preprocessor_filename: Option<PathBuf>,
    imatrix: Option<PathBuf>,
}

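/// [`Loader`] implementation for vision models. Resolves model artifacts
/// (locally or from the Hugging Face Hub) and constructs a [`VisionPipeline`].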
pub struct VisionLoader {
    inner: Box<dyn VisionModelLoader>,
    model_id: String,
    config: VisionSpecificConfig,
    kind: ModelKind,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
    lora_adapter_ids: Option<Vec<String>>,
}

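/// Builder for a vision model [`Loader`].
///
/// A minimal usage sketch; the model ID below is a placeholder, and passing
/// `None` to [`VisionLoaderBuilder::build`] selects [`AutoVisionLoader`]:
///
/// ```ignore
/// let loader = VisionLoaderBuilder::new(
///     VisionSpecificConfig::default(),
///     None, // chat template override
///     None, // tokenizer.json override
///     Some("some-org/some-vision-model".to_string()),
///     None, // explicit JINJA template
/// )
/// .build(None);
/// ```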
99#[derive(Default)]
100pub struct VisionLoaderBuilder {
102 model_id: Option<String>,
103 config: VisionSpecificConfig,
104 kind: ModelKind,
105 chat_template: Option<String>,
106 tokenizer_json: Option<String>,
107 jinja_explicit: Option<String>,
108 hf_cache_path: Option<PathBuf>,
109 lora_adapter_ids: Option<Vec<String>>,
110}
111
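/// Configuration specific to vision pipelines: prompt chunking, device
/// topology, UQFF (de)serialization paths, the maximum image edge, and
/// imatrix/calibration data for ISQ.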
#[derive(Clone, Default)]
pub struct VisionSpecificConfig {
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub topology: Option<Topology>,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub max_edge: Option<u32>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
}

impl VisionLoaderBuilder {
    pub fn new(
        config: VisionSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            jinja_explicit,
            kind: ModelKind::Normal,
            hf_cache_path: None,
            ..Default::default()
        }
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn build(self, loader: Option<VisionLoaderType>) -> Box<dyn Loader> {
        let loader: Box<dyn VisionModelLoader> = match loader {
            Some(VisionLoaderType::Phi3V) => Box::new(Phi3VLoader),
            Some(VisionLoaderType::Idefics2) => Box::new(Idefics2Loader),
            Some(VisionLoaderType::LLaVANext) => Box::new(LLaVANextLoader),
            Some(VisionLoaderType::LLaVA) => Box::new(LLaVALoader),
            Some(VisionLoaderType::VLlama) => Box::new(VLlamaLoader),
            Some(VisionLoaderType::Qwen2VL) => Box::new(Qwen2VLLoader),
            Some(VisionLoaderType::Idefics3) => Box::new(Idefics3Loader),
            Some(VisionLoaderType::MiniCpmO) => Box::new(MiniCpmOLoader),
            Some(VisionLoaderType::Phi4MM) => Box::new(Phi4MMLoader),
            Some(VisionLoaderType::Qwen2_5VL) => Box::new(Qwen2_5VLLoader),
            Some(VisionLoaderType::Gemma3) => Box::new(Gemma3Loader),
            Some(VisionLoaderType::Mistral3) => Box::new(Mistral3Loader),
            Some(VisionLoaderType::Llama4) => Box::new(VLlama4Loader),
            None => Box::new(AutoVisionLoader),
        };
        Box::new(VisionLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            kind: self.kind,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            xlora_model_id: None,
            xlora_order: None,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
            lora_adapter_ids: self.lora_adapter_ids,
        })
    }
}

impl Loader for VisionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention(&config) {
            paged_attn_config = None;
        }

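        // Multi-node daemon workers receive their rank via an env payload and
        // each bind a dedicated CUDA stream; with NCCL the head node uses
        // stream 0. Otherwise, gather every device similar to the requested one
        // so the device mapper can spread layers across them.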
        let use_nccl = mistralrs_quant::distributed::use_nccl();

        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda_with_stream(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        if use_nccl {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(mut params) = mapper.clone() {
            params = params.maybe_promote_to_vision();

            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // A checkpoint that is already quantized (pack factor != 1) cannot
            // be re-quantized in situ.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // Estimate per-layer and non-mapped sizes for the automatic device
            // mapper. The effective weight size depends on the pack factor of
            // whichever quantization applies: serialized UQFF artifacts,
            // requested ISQ, or a pre-quantized model config.
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        // Use the average pack factor across all serialized tensors.
                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                params.max_seq_len(),
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

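        // Choose between immediate ISQ (quantize tensors as they are loaded)
        // and post-load ISQ. Immediate ISQ requires that no imatrix or
        // calibration data is involved and that no UQFF is being written (and,
        // here, a non-CUDA device), since those paths need the unquantized
        // weights first.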
        let mut loading_isq = if self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && self.config.write_uqff.is_none()
            && in_situ_quant.is_some()
        {
            let predicates = self.inner.immediate_isq_predicates(&config)?;
            info!("Applying ISQ to {in_situ_quant:?}");
            if predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
            false
        } else {
            in_situ_quant.is_some()
        };

        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

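        // Post-load ISQ loads the weights on CPU and quantizes them onto the
        // target device; calibration runs instead need the model resident on
        // the real device to collect activation statistics.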
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                IsqOrganization::Default,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => vision_normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => vision_normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    multi_progress,
                ),
                _ => unreachable!(),
            }
        };

        let preprocessor_config: PreProcessorConfig = match paths.get_preprocessor_config().as_ref()
        {
            Some(preprocessor_config) => {
                serde_json::from_str(&fs::read_to_string(preprocessor_config).unwrap()).unwrap()
            }
            None => PreProcessorConfig::default(),
        };
        let processor_config: Option<ProcessorConfig> = paths
            .get_processor_config()
            .as_ref()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());

        let processor = self.inner.get_processor(
            &config,
            processor_config,
            preprocessor_config.clone(),
            self.config.max_edge,
        );

        let tokenizer = get_tokenizer(
            paths.get_tokenizer_filename(),
            Some(processor.get_special_tokens()),
        )?;

        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

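        // Imatrix collection: run the calibration corpus through the model in
        // fixed-size chunks, resetting the KV cache between chunks, so the
        // recorded activation statistics can guide the quantization below.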
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            model.begin_track_stats()?;

            const CHUNK_SIZE: usize = 1024;
            let n_chunks: usize = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    None,
                )?;
                let _ = model.forward(
                    &inputs.input,
                    None, // no pixel values in the text-only calibration pass
                    &inputs.positions,
                    inputs.context_lens,
                    inputs.position_ids,
                    model.default_model_specific_args(&inputs.input),
                    None,
                    &inputs.flash_meta,
                )?;
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }
                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

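        // Apply ISQ now, or load pre-quantized UQFF artifacts instead. The
        // imatrix comes either from a file or from the calibration pass above;
        // specifying both was rejected earlier, so that arm is unreachable.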
        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                (Some(_), true) => unreachable!(),
            };
            model.quantize(
                in_situ_quant,
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                IsqOrganization::Default,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: paths.get_processor_config(),
                    preprocessor_filename: paths.get_preprocessor_config(),
                },
                Arc::new(MultiProgress::new()),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

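        // Size and allocate the PagedAttention KV cache from the configured
        // GPU/CPU memory budget; without PagedAttention there is no cache
        // engine.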
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            anyhow::ensure!(
                !matches!(self.kind, ModelKind::Adapter { .. }),
                "PagedAttention does not support adapter models."
            );
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &layer_devices,
                silent,
            )?;
            let cache_engine =
                CacheEngine::new(model.config(), &cache_config, dtype, &device, layer_devices)?;
            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

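        // Assemble the pipeline and the metadata shared with the engine and
        // scheduler (sequence limits, EOS tokens, KV cache handles, modalities).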
        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());
        Ok(Arc::new(Mutex::new(VisionPipeline {
            model,
            tokenizer: tokenizer.into(),
            chat_template: Arc::new(chat_template),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                is_xlora: false,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                no_kv_cache: false,
                no_prefix_cache: !self.inner.supports_prefix_cacher(&config),
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                prompt_chunksize: self.config.prompt_chunksize,
                model_metadata: Some(model_metadata),
                modalities: self.inner.modalities(&config)?,
            }),
            processor,
            prefixer: self.inner.prefixer(&config),
            preprocessor_config: Arc::new(preprocessor_config),
            topology: self.config.topology.clone(),
            silent,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            processor_filename: paths.get_processor_config().clone(),
            preprocessor_filename: paths.get_preprocessor_config().clone(),
            mapper: pipeline_mapper,
            imatrix: self.config.imatrix.clone(),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for VisionPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        Some(self.preprocessor_config.clone())
    }
    fn get_processor(&self) -> Arc<dyn super::Processor> {
        self.processor.clone()
    }
}

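// Runtime re-quantization: quantize the already-loaded model to a new ISQ
// type, passing along the stored file paths required for UQFF serialization.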
impl IsqPipelineMixin for VisionPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        self.model
            .quantize(
                Some(dtype),
                device,
                self.topology.as_ref(),
                self.silent,
                self.imatrix.as_ref().map(ImatrixDataSource::File),
                IsqOrganization::Default,
                None,
                UqffFullSer {
                    tokenizer: &self.tokenizer,
                    template_filename: &self.template_filename,
                    generation_config: self.generation_config.as_ref(),
                    config: self.config.clone(),
                    processor_filename: &self.processor_filename,
                    preprocessor_filename: &self.preprocessor_filename,
                },
                Arc::new(MultiProgress::new()),
            )
            .map_err(anyhow::Error::msg)
    }
}

impl CacheManagerMixin for VisionPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_in_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_in_cache(self, seqs, false)
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.clone_out_cache(self, seqs, false)
        } else {
            NormalCacheManager.clone_out_cache(self, seqs, false)
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        if matches!(self.model.cache(), EitherCache::Full(_)) {
            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
        } else {
            NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            );
        }
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for VisionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

#[async_trait::async_trait]
impl Pipeline for VisionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        let ModelInputs {
            input_ids,
            seqlen_offsets,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args,
            paged_attn_meta,
            flash_meta,
        } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = self.model.forward(
            &input_ids,
            pixel_values,
            &seqlen_offsets,
            context_lens,
            position_ids,
            model_specific_args,
            paged_attn_meta,
            &flash_meta,
        )?;
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Vision {
            prefixer: self.prefixer.clone(),
        }
    }
}

impl AnyMoePipelineMixin for VisionPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
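    // Load one VarBuilder per expert model (plus an optional gating model),
    // keeping only tensors whose names match `match_regex` and whose layer
    // index is listed in `layers` (an empty list keeps all matching layers).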
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        // Parse the layer index that precedes the matched
                        // segment, e.g. `...layers.3.mlp...` -> 3.
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model
            .create_anymoe_layers(vbs, config, (prefix, mlp), layers, expert_type, gate_vb)
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}